Apply black and isort auto formatting

Signed-off-by: Marek Pikuła <marek.pikula@embevity.com>
This commit is contained in:
Marek Pikuła
2023-04-03 12:11:45 +00:00
parent 4919147483
commit 358c0086af
4 changed files with 1625 additions and 859 deletions

View File

@@ -1,27 +1,36 @@
# pylint: disable=wrong-import-order import json
import logging
import os
from datetime import date, timedelta
import requests, json, os, logging, yaml import requests
import yaml
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from datetime import timedelta, date
from dateutil import parser from dateutil import parser
from flask import Flask from flask import Flask
LOG_LEVEL = os.environ["LOG_LEVEL"].replace('"', '').upper() LOG_LEVEL = os.environ["LOG_LEVEL"].replace('"', "").upper()
# Initiate the Flask application and logging: # Initiate the Flask application and logging:
app = Flask(__name__, static_url_path="/static") app = Flask(__name__, static_url_path="/static")
match LOG_LEVEL: match LOG_LEVEL:
case "DEBUG" : app.logger.setLevel(logging.DEBUG) case "DEBUG":
case "INFO" : app.logger.setLevel(logging.INFO) app.logger.setLevel(logging.DEBUG)
case "WARNING" : app.logger.setLevel(logging.WARNING) case "INFO":
case "ERROR" : app.logger.setLevel(logging.ERROR) app.logger.setLevel(logging.INFO)
case "CRITICAL": app.logger.setLevel(logging.CRITICAL) case "WARNING":
app.logger.setLevel(logging.WARNING)
case "ERROR":
app.logger.setLevel(logging.ERROR)
case "CRITICAL":
app.logger.setLevel(logging.CRITICAL)
################################################################## ##################################################################
# Functions related to HEADSCALE and API KEYS # Functions related to HEADSCALE and API KEYS
################################################################## ##################################################################
def get_url(inpage=False): def get_url(inpage=False):
if not inpage: if not inpage:
return os.environ['HS_SERVER'] return os.environ["HS_SERVER"]
config_file = "" config_file = ""
try: try:
config_file = open("/etc/headscale/config.yml", "r") config_file = open("/etc/headscale/config.yml", "r")
@@ -32,12 +41,15 @@ def get_url(inpage=False):
config_yaml = yaml.safe_load(config_file) config_yaml = yaml.safe_load(config_file)
if "server_url" in config_yaml: if "server_url" in config_yaml:
return str(config_yaml["server_url"]) return str(config_yaml["server_url"])
app.logger.warning("Failed to find server_url in the config. Falling back to ENV variable") app.logger.warning(
return os.environ['HS_SERVER'] "Failed to find server_url in the config. Falling back to ENV variable"
)
return os.environ["HS_SERVER"]
def set_api_key(api_key): def set_api_key(api_key):
# User-set encryption key # User-set encryption key
encryption_key = os.environ['KEY'] encryption_key = os.environ["KEY"]
# Key file on the filesystem for persistent storage # Key file on the filesystem for persistent storage
key_file = open("/data/key.txt", "wb+") key_file = open("/data/key.txt", "wb+")
# Preparing the Fernet class with the key # Preparing the Fernet class with the key
@@ -47,15 +59,18 @@ def set_api_key(api_key):
# Return true if the file wrote correctly # Return true if the file wrote correctly
return True if key_file.write(encrypted_key) else False return True if key_file.write(encrypted_key) else False
def get_api_key(): def get_api_key():
if not os.path.exists("/data/key.txt"): return False if not os.path.exists("/data/key.txt"):
return False
# User-set encryption key # User-set encryption key
encryption_key = os.environ['KEY'] encryption_key = os.environ["KEY"]
# Key file on the filesystem for persistent storage # Key file on the filesystem for persistent storage
key_file = open("/data/key.txt", "rb+") key_file = open("/data/key.txt", "rb+")
# The encrypted key read from the file # The encrypted key read from the file
enc_api_key = key_file.read() enc_api_key = key_file.read()
if enc_api_key == b'': return "NULL" if enc_api_key == b"":
return "NULL"
# Preparing the Fernet class with the key # Preparing the Fernet class with the key
fernet = Fernet(encryption_key) fernet = Fernet(encryption_key)
@@ -64,33 +79,38 @@ def get_api_key():
return decrypted_key return decrypted_key
def test_api_key(url, api_key): def test_api_key(url, api_key):
response = requests.get( response = requests.get(
str(url) + "/api/v1/apikey", str(url) + "/api/v1/apikey",
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
return response.status_code return response.status_code
# Expires an API key # Expires an API key
def expire_key(url, api_key): def expire_key(url, api_key):
payload = {'prefix':str(api_key[0:10])} payload = {"prefix": str(api_key[0:10])}
json_payload = json.dumps(payload) json_payload = json.dumps(payload)
app.logger.debug("Sending the payload '"+str(json_payload)+"' to the headscale server") app.logger.debug(
"Sending the payload '" + str(json_payload) + "' to the headscale server"
)
response = requests.post( response = requests.post(
str(url) + "/api/v1/apikey/expire", str(url) + "/api/v1/apikey/expire",
data=json_payload, data=json_payload,
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Content-Type': 'application/json', "Content-Type": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
return response.status_code return response.status_code
# Checks if the key needs to be renewed # Checks if the key needs to be renewed
# If it does, renews the key, then expires the old key # If it does, renews the key, then expires the old key
def renew_api_key(url, api_key): def renew_api_key(url, api_key):
@@ -101,7 +121,13 @@ def renew_api_key(url, api_key):
expiration_time = key_info["expiration"] expiration_time = key_info["expiration"]
today_date = date.today() today_date = date.today()
expire = parser.parse(expiration_time) expire = parser.parse(expiration_time)
expire_fmt = str(expire.year) + "-" + str(expire.month).zfill(2) + "-" + str(expire.day).zfill(2) expire_fmt = (
str(expire.year)
+ "-"
+ str(expire.month).zfill(2)
+ "-"
+ str(expire.day).zfill(2)
)
expire_date = date.fromisoformat(expire_fmt) expire_date = date.fromisoformat(expire_fmt)
delta = expire_date - today_date delta = expire_date - today_date
tmp = today_date + timedelta(days=90) tmp = today_date + timedelta(days=90)
@@ -110,18 +136,20 @@ def renew_api_key(url, api_key):
# If the delta is less than 5 days, renew the key: # If the delta is less than 5 days, renew the key:
if delta < timedelta(days=5): if delta < timedelta(days=5):
app.logger.warning("Key is about to expire. Delta is " + str(delta)) app.logger.warning("Key is about to expire. Delta is " + str(delta))
payload = {'expiration':str(new_expiration_date)} payload = {"expiration": str(new_expiration_date)}
json_payload = json.dumps(payload) json_payload = json.dumps(payload)
app.logger.debug("Sending the payload '"+str(json_payload)+"' to the headscale server") app.logger.debug(
"Sending the payload '" + str(json_payload) + "' to the headscale server"
)
response = requests.post( response = requests.post(
str(url) + "/api/v1/apikey", str(url) + "/api/v1/apikey",
data=json_payload, data=json_payload,
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Content-Type': 'application/json', "Content-Type": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
new_key = response.json() new_key = response.json()
app.logger.debug("JSON: " + json.dumps(new_key)) app.logger.debug("JSON: " + json.dumps(new_key))
@@ -140,7 +168,9 @@ def renew_api_key(url, api_key):
else: else:
app.logger.error("Testing the API key failed.") app.logger.error("Testing the API key failed.")
return False # The API Key test failed return False # The API Key test failed
else: return True # No work is required else:
return True # No work is required
# Gets information about the current API key # Gets information about the current API key
def get_api_key_info(url, api_key): def get_api_key_info(url, api_key):
@@ -148,9 +178,9 @@ def get_api_key_info(url, api_key):
response = requests.get( response = requests.get(
str(url) + "/api/v1/apikey", str(url) + "/api/v1/apikey",
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
json_response = response.json() json_response = response.json()
# Find the current key in the array: # Find the current key in the array:
@@ -163,19 +193,25 @@ def get_api_key_info(url, api_key):
app.logger.error("Could not find a valid key in Headscale. Need a new API key.") app.logger.error("Could not find a valid key in Headscale. Need a new API key.")
return "Key not found" return "Key not found"
################################################################## ##################################################################
# Functions related to MACHINES # Functions related to MACHINES
################################################################## ##################################################################
# register a new machine # register a new machine
def register_machine(url, api_key, machine_key, user): def register_machine(url, api_key, machine_key, user):
app.logger.info("Registering machine %s to user %s", str(machine_key), str(user)) app.logger.info("Registering machine %s to user %s", str(machine_key), str(user))
response = requests.post( response = requests.post(
str(url)+"/api/v1/machine/register?user="+str(user)+"&key="+str(machine_key), str(url)
+ "/api/v1/machine/register?user="
+ str(user)
+ "&key="
+ str(machine_key),
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
return response.json() return response.json()
@@ -187,25 +223,27 @@ def set_machine_tags(url, api_key, machine_id, tags_list):
str(url) + "/api/v1/machine/" + str(machine_id) + "/tags", str(url) + "/api/v1/machine/" + str(machine_id) + "/tags",
data=tags_list, data=tags_list,
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Content-Type': 'application/json', "Content-Type": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
return response.json() return response.json()
# Moves machine_id to user "new_user" # Moves machine_id to user "new_user"
def move_user(url, api_key, machine_id, new_user): def move_user(url, api_key, machine_id, new_user):
app.logger.info("Moving machine_id %s to user %s", str(machine_id), str(new_user)) app.logger.info("Moving machine_id %s to user %s", str(machine_id), str(new_user))
response = requests.post( response = requests.post(
str(url) + "/api/v1/machine/" + str(machine_id) + "/user?user=" + str(new_user), str(url) + "/api/v1/machine/" + str(machine_id) + "/user?user=" + str(new_user),
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
return response.json() return response.json()
def update_route(url, api_key, route_id, current_state): def update_route(url, api_key, route_id, current_state):
action = "disable" if current_state == "True" else "enable" action = "disable" if current_state == "True" else "enable"
@@ -220,45 +258,48 @@ def update_route(url, api_key, route_id, current_state):
response = requests.post( response = requests.post(
str(url) + "/api/v1/routes/" + str(route_id) + "/" + str(action), str(url) + "/api/v1/routes/" + str(route_id) + "/" + str(action),
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
return response.json() return response.json()
# Get all machines on the Headscale network # Get all machines on the Headscale network
def get_machines(url, api_key): def get_machines(url, api_key):
app.logger.info("Getting machine information") app.logger.info("Getting machine information")
response = requests.get( response = requests.get(
str(url) + "/api/v1/machine", str(url) + "/api/v1/machine",
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
return response.json() return response.json()
# Get machine with "machine_id" on the Headscale network # Get machine with "machine_id" on the Headscale network
def get_machine_info(url, api_key, machine_id): def get_machine_info(url, api_key, machine_id):
app.logger.info("Getting information for machine ID %s", str(machine_id)) app.logger.info("Getting information for machine ID %s", str(machine_id))
response = requests.get( response = requests.get(
str(url) + "/api/v1/machine/" + str(machine_id), str(url) + "/api/v1/machine/" + str(machine_id),
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
return response.json() return response.json()
# Delete a machine from Headscale # Delete a machine from Headscale
def delete_machine(url, api_key, machine_id): def delete_machine(url, api_key, machine_id):
app.logger.info("Deleting machine %s", str(machine_id)) app.logger.info("Deleting machine %s", str(machine_id))
response = requests.delete( response = requests.delete(
str(url) + "/api/v1/machine/" + str(machine_id), str(url) + "/api/v1/machine/" + str(machine_id),
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
status = "True" if response.status_code == 200 else "False" status = "True" if response.status_code == 200 else "False"
if response.status_code == 200: if response.status_code == 200:
@@ -267,15 +308,16 @@ def delete_machine(url, api_key, machine_id):
app.logger.error("Deleting machine failed! %s", str(response.json())) app.logger.error("Deleting machine failed! %s", str(response.json()))
return {"status": status, "body": response.json()} return {"status": status, "body": response.json()}
# Rename "machine_id" with name "new_name" # Rename "machine_id" with name "new_name"
def rename_machine(url, api_key, machine_id, new_name): def rename_machine(url, api_key, machine_id, new_name):
app.logger.info("Renaming machine %s", str(machine_id)) app.logger.info("Renaming machine %s", str(machine_id))
response = requests.post( response = requests.post(
str(url) + "/api/v1/machine/" + str(machine_id) + "/rename/" + str(new_name), str(url) + "/api/v1/machine/" + str(machine_id) + "/rename/" + str(new_name),
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
status = "True" if response.status_code == 200 else "False" status = "True" if response.status_code == 200 else "False"
if response.status_code == 200: if response.status_code == 200:
@@ -284,15 +326,16 @@ def rename_machine(url, api_key, machine_id, new_name):
app.logger.error("Machine rename failed! %s", str(response.json())) app.logger.error("Machine rename failed! %s", str(response.json()))
return {"status": status, "body": response.json()} return {"status": status, "body": response.json()}
# Gets routes for the passed machine_id # Gets routes for the passed machine_id
def get_machine_routes(url, api_key, machine_id): def get_machine_routes(url, api_key, machine_id):
app.logger.info("Getting routes for machine %s", str(machine_id)) app.logger.info("Getting routes for machine %s", str(machine_id))
response = requests.get( response = requests.get(
str(url) + "/api/v1/machine/" + str(machine_id) + "/routes", str(url) + "/api/v1/machine/" + str(machine_id) + "/routes",
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
if response.status_code == 200: if response.status_code == 200:
app.logger.info("Routes obtained") app.logger.info("Routes obtained")
@@ -300,42 +343,47 @@ def get_machine_routes(url, api_key, machine_id):
app.logger.error("Failed to get routes: %s", str(response.json())) app.logger.error("Failed to get routes: %s", str(response.json()))
return response.json() return response.json()
# Gets routes for the entire tailnet # Gets routes for the entire tailnet
def get_routes(url, api_key): def get_routes(url, api_key):
app.logger.info("Getting routes") app.logger.info("Getting routes")
response = requests.get( response = requests.get(
str(url) + "/api/v1/routes", str(url) + "/api/v1/routes",
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
return response.json() return response.json()
################################################################## ##################################################################
# Functions related to USERS # Functions related to USERS
################################################################## ##################################################################
# Get all users in use # Get all users in use
def get_users(url, api_key): def get_users(url, api_key):
app.logger.info("Getting Users") app.logger.info("Getting Users")
response = requests.get( response = requests.get(
str(url) + "/api/v1/user", str(url) + "/api/v1/user",
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
return response.json() return response.json()
# Rename "old_name" with name "new_name" # Rename "old_name" with name "new_name"
def rename_user(url, api_key, old_name, new_name): def rename_user(url, api_key, old_name, new_name):
app.logger.info("Renaming user %s to %s.", str(old_name), str(new_name)) app.logger.info("Renaming user %s to %s.", str(old_name), str(new_name))
response = requests.post( response = requests.post(
str(url) + "/api/v1/user/" + str(old_name) + "/rename/" + str(new_name), str(url) + "/api/v1/user/" + str(old_name) + "/rename/" + str(new_name),
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
status = "True" if response.status_code == 200 else "False" status = "True" if response.status_code == 200 else "False"
if response.status_code == 200: if response.status_code == 200:
@@ -344,15 +392,16 @@ def rename_user(url, api_key, old_name, new_name):
app.logger.error("Renaming User failed!") app.logger.error("Renaming User failed!")
return {"status": status, "body": response.json()} return {"status": status, "body": response.json()}
# Delete a user from Headscale # Delete a user from Headscale
def delete_user(url, api_key, user_name): def delete_user(url, api_key, user_name):
app.logger.info("Deleting a User: %s", str(user_name)) app.logger.info("Deleting a User: %s", str(user_name))
response = requests.delete( response = requests.delete(
str(url) + "/api/v1/user/" + str(user_name), str(url) + "/api/v1/user/" + str(user_name),
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
status = "True" if response.status_code == 200 else "False" status = "True" if response.status_code == 200 else "False"
if response.status_code == 200: if response.status_code == 200:
@@ -361,6 +410,7 @@ def delete_user(url, api_key, user_name):
app.logger.error("Deleting User failed!") app.logger.error("Deleting User failed!")
return {"status": status, "body": response.json()} return {"status": status, "body": response.json()}
# Add a user from Headscale # Add a user from Headscale
def add_user(url, api_key, data): def add_user(url, api_key, data):
app.logger.info("Adding user: %s", str(data)) app.logger.info("Adding user: %s", str(data))
@@ -368,10 +418,10 @@ def add_user(url, api_key, data):
str(url) + "/api/v1/user", str(url) + "/api/v1/user",
data=data, data=data,
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Content-Type': 'application/json', "Content-Type": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
status = "True" if response.status_code == 200 else "False" status = "True" if response.status_code == 200 else "False"
if response.status_code == 200: if response.status_code == 200:
@@ -380,22 +430,25 @@ def add_user(url, api_key, data):
app.logger.error("Adding User failed!") app.logger.error("Adding User failed!")
return {"status": status, "body": response.json()} return {"status": status, "body": response.json()}
################################################################## ##################################################################
# Functions related to PREAUTH KEYS in USERS # Functions related to PREAUTH KEYS in USERS
################################################################## ##################################################################
# Get all PreAuth keys associated with a user "user_name" # Get all PreAuth keys associated with a user "user_name"
def get_preauth_keys(url, api_key, user_name): def get_preauth_keys(url, api_key, user_name):
app.logger.info("Getting PreAuth Keys in User %s", str(user_name)) app.logger.info("Getting PreAuth Keys in User %s", str(user_name))
response = requests.get( response = requests.get(
str(url) + "/api/v1/preauthkey?user=" + str(user_name), str(url) + "/api/v1/preauthkey?user=" + str(user_name),
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
return response.json() return response.json()
# Add a preauth key to the user "user_name" given the booleans "ephemeral" # Add a preauth key to the user "user_name" given the booleans "ephemeral"
# and "reusable" with the expiration date "date" contained in the JSON payload "data" # and "reusable" with the expiration date "date" contained in the JSON payload "data"
def add_preauth_key(url, api_key, data): def add_preauth_key(url, api_key, data):
@@ -404,10 +457,10 @@ def add_preauth_key(url, api_key, data):
str(url) + "/api/v1/preauthkey", str(url) + "/api/v1/preauthkey",
data=data, data=data,
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Content-Type': 'application/json', "Content-Type": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
status = "True" if response.status_code == 200 else "False" status = "True" if response.status_code == 200 else "False"
if response.status_code == 200: if response.status_code == 200:
@@ -416,6 +469,7 @@ def add_preauth_key(url, api_key, data):
app.logger.error("Adding PreAuth Key failed!") app.logger.error("Adding PreAuth Key failed!")
return {"status": status, "body": response.json()} return {"status": status, "body": response.json()}
# Expire a pre-auth key. data is {"user": "string", "key": "string"} # Expire a pre-auth key. data is {"user": "string", "key": "string"}
def expire_preauth_key(url, api_key, data): def expire_preauth_key(url, api_key, data):
app.logger.info("Expiring PreAuth Key...") app.logger.info("Expiring PreAuth Key...")
@@ -423,10 +477,10 @@ def expire_preauth_key(url, api_key, data):
str(url) + "/api/v1/preauthkey/expire", str(url) + "/api/v1/preauthkey/expire",
data=data, data=data,
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Content-Type': 'application/json', "Content-Type": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
status = "True" if response.status_code == 200 else "False" status = "True" if response.status_code == 200 else "False"
app.logger.debug("expire_preauth_key - Return: " + str(response.json())) app.logger.debug("expire_preauth_key - Return: " + str(response.json()))

187
helper.py
View File

@@ -1,57 +1,106 @@
# pylint: disable=wrong-import-order import logging
import os
import os, headscale, requests, logging import requests
from flask import Flask from flask import Flask
LOG_LEVEL = os.environ["LOG_LEVEL"].replace('"', '').upper() import headscale
LOG_LEVEL = os.environ["LOG_LEVEL"].replace('"', "").upper()
# Initiate the Flask application and logging: # Initiate the Flask application and logging:
app = Flask(__name__, static_url_path="/static") app = Flask(__name__, static_url_path="/static")
match LOG_LEVEL: match LOG_LEVEL:
case "DEBUG" : app.logger.setLevel(logging.DEBUG) case "DEBUG":
case "INFO" : app.logger.setLevel(logging.INFO) app.logger.setLevel(logging.DEBUG)
case "WARNING" : app.logger.setLevel(logging.WARNING) case "INFO":
case "ERROR" : app.logger.setLevel(logging.ERROR) app.logger.setLevel(logging.INFO)
case "CRITICAL": app.logger.setLevel(logging.CRITICAL) case "WARNING":
app.logger.setLevel(logging.WARNING)
case "ERROR":
app.logger.setLevel(logging.ERROR)
case "CRITICAL":
app.logger.setLevel(logging.CRITICAL)
def pretty_print_duration(duration, delta_type=""): def pretty_print_duration(duration, delta_type=""):
"""Prints a duration in human-readable formats""" """Prints a duration in human-readable formats"""
days, seconds = duration.days, duration.seconds days, seconds = duration.days, duration.seconds
hours = (days * 24 + seconds // 3600) hours = days * 24 + seconds // 3600
mins = (seconds % 3600) // 60 mins = (seconds % 3600) // 60
secs = seconds % 60 secs = seconds % 60
if delta_type == "expiry": if delta_type == "expiry":
if days > 730: return "in greater than two years" if days > 730:
if days > 365: return "in greater than a year" return "in greater than two years"
if days > 0 : return "in "+ str(days ) + " days" if days > 1 else "in "+ str(days ) + " day" if days > 365:
if hours > 0 : return "in "+ str(hours) + " hours" if hours > 1 else "in "+ str(hours) + " hour" return "in greater than a year"
if mins > 0 : return "in "+ str(mins ) + " minutes" if mins > 1 else "in "+ str(mins ) + " minute" if days > 0:
return "in "+ str(secs ) + " seconds" if secs >= 1 or secs == 0 else "in "+ str(secs ) + " second" return (
if days > 730: return "over two years ago" "in " + str(days) + " days" if days > 1 else "in " + str(days) + " day"
if days > 365: return "over a year ago" )
if days > 0 : return str(days ) + " days ago" if days > 1 else str(days ) + " day ago" if hours > 0:
if hours > 0 : return str(hours) + " hours ago" if hours > 1 else str(hours) + " hour ago" return (
if mins > 0 : return str(mins ) + " minutes ago" if mins > 1 else str(mins ) + " minute ago" "in " + str(hours) + " hours"
return str(secs ) + " seconds ago" if secs >= 1 or secs == 0 else str(secs ) + " second ago" if hours > 1
else "in " + str(hours) + " hour"
)
if mins > 0:
return (
"in " + str(mins) + " minutes"
if mins > 1
else "in " + str(mins) + " minute"
)
return (
"in " + str(secs) + " seconds"
if secs >= 1 or secs == 0
else "in " + str(secs) + " second"
)
if days > 730:
return "over two years ago"
if days > 365:
return "over a year ago"
if days > 0:
return str(days) + " days ago" if days > 1 else str(days) + " day ago"
if hours > 0:
return str(hours) + " hours ago" if hours > 1 else str(hours) + " hour ago"
if mins > 0:
return str(mins) + " minutes ago" if mins > 1 else str(mins) + " minute ago"
return (
str(secs) + " seconds ago"
if secs >= 1 or secs == 0
else str(secs) + " second ago"
)
def text_color_duration(duration): def text_color_duration(duration):
"""Prints a color based on duratioin (imported as seconds)""" """Prints a color based on duratioin (imported as seconds)"""
days, seconds = duration.days, duration.seconds days, seconds = duration.days, duration.seconds
hours = (days * 24 + seconds // 3600) hours = days * 24 + seconds // 3600
mins = ((seconds % 3600) // 60) mins = (seconds % 3600) // 60
secs = (seconds % 60) secs = seconds % 60
if days > 30: return "grey-text " if days > 30:
if days > 14: return "red-text text-darken-2 " return "grey-text "
if days > 5: return "deep-orange-text text-lighten-1" if days > 14:
if days > 1: return "deep-orange-text text-lighten-1" return "red-text text-darken-2 "
if hours > 12: return "orange-text " if days > 5:
if hours > 1: return "orange-text text-lighten-2" return "deep-orange-text text-lighten-1"
if hours == 1: return "yellow-text " if days > 1:
if mins > 15: return "yellow-text text-lighten-2" return "deep-orange-text text-lighten-1"
if mins > 5: return "green-text text-lighten-3" if hours > 12:
if secs > 30: return "green-text text-lighten-2" return "orange-text "
if hours > 1:
return "orange-text text-lighten-2"
if hours == 1:
return "yellow-text "
if mins > 15:
return "yellow-text text-lighten-2"
if mins > 5:
return "green-text text-lighten-3"
if secs > 30:
return "green-text text-lighten-2"
return "green-text " return "green-text "
def key_check(): def key_check():
"""Checks the validity of a Headsclae API key and renews it if it's nearing expiration""" """Checks the validity of a Headsclae API key and renews it if it's nearing expiration"""
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
@@ -62,7 +111,10 @@ def key_check():
app.logger.info("Testing API key validity.") app.logger.info("Testing API key validity.")
status = headscale.test_api_key(url, api_key) status = headscale.test_api_key(url, api_key)
if status != 200: if status != 200:
app.logger.info("Got a non-200 response from Headscale. Test failed (Response: %i)", status) app.logger.info(
"Got a non-200 response from Headscale. Test failed (Response: %i)",
status,
)
return False return False
else: else:
app.logger.info("Key check passed.") app.logger.info("Key check passed.")
@@ -70,6 +122,7 @@ def key_check():
headscale.renew_api_key(url, api_key) headscale.renew_api_key(url, api_key)
return True return True
def get_color(import_id, item_type=""): def get_color(import_id, item_type=""):
"""Sets colors for users/namespaces""" """Sets colors for users/namespaces"""
# Define the colors... Seems like a good number to start with # Define the colors... Seems like a good number to start with
@@ -122,6 +175,7 @@ def get_color(import_id, item_type = ""):
index = import_id % len(colors) index = import_id % len(colors)
return colors[index] return colors[index]
def format_message(error_type, title, message): def format_message(error_type, title, message):
"""Defines a generic 'collection' as error/warning/info messages""" """Defines a generic 'collection' as error/warning/info messages"""
content = """ content = """
@@ -144,13 +198,17 @@ def format_message(error_type, title, message):
title = """<span class="title">Information - """ + title + """</span>""" title = """<span class="title">Information - """ + title + """</span>"""
content = content + icon + title + message content = content + icon + title + message
content = content+""" content = (
content
+ """
</li> </li>
</ul> </ul>
""" """
)
return content return content
def access_checks(): def access_checks():
"""Checks various items before each page load to ensure permissions are correct""" """Checks various items before each page load to ensure permissions are correct"""
url = headscale.get_url() url = headscale.get_url()
@@ -166,7 +224,6 @@ def access_checks():
file_exists = False # Checks if /data/key.txt exists file_exists = False # Checks if /data/key.txt exists
config_readable = False # Checks if the headscale configuration file is readable config_readable = False # Checks if the headscale configuration file is readable
# Check 1: Check: the Headscale server is reachable: # Check 1: Check: the Headscale server is reachable:
server_reachable = False server_reachable = False
response = requests.get(str(url) + "/health") response = requests.get(str(url) + "/health")
@@ -177,35 +234,43 @@ def access_checks():
app.logger.critical("Headscale URL: Response 200: FAILED") app.logger.critical("Headscale URL: Response 200: FAILED")
# Check: /data is rwx for 1000:1000: # Check: /data is rwx for 1000:1000:
if os.access('/data/', os.R_OK): data_readable = True if os.access("/data/", os.R_OK):
data_readable = True
else: else:
app.logger.critical("/data READ: FAILED") app.logger.critical("/data READ: FAILED")
checks_passed = False checks_passed = False
if os.access('/data/', os.W_OK): data_writable = True if os.access("/data/", os.W_OK):
data_writable = True
else: else:
app.logger.critical("/data WRITE: FAILED") app.logger.critical("/data WRITE: FAILED")
checks_passed = False checks_passed = False
if os.access('/data/', os.X_OK): data_executable = True if os.access("/data/", os.X_OK):
data_executable = True
else: else:
app.logger.critical("/data EXEC: FAILED") app.logger.critical("/data EXEC: FAILED")
checks_passed = False checks_passed = False
# Check: /data/key.txt exists and is rw: # Check: /data/key.txt exists and is rw:
if os.access('/data/key.txt', os.F_OK): if os.access("/data/key.txt", os.F_OK):
file_exists = True file_exists = True
if os.access('/data/key.txt', os.R_OK): file_readable = True if os.access("/data/key.txt", os.R_OK):
file_readable = True
else: else:
app.logger.critical("/data/key.txt READ: FAILED") app.logger.critical("/data/key.txt READ: FAILED")
checks_passed = False checks_passed = False
if os.access('/data/key.txt', os.W_OK): file_writable = True if os.access("/data/key.txt", os.W_OK):
file_writable = True
else: else:
app.logger.critical("/data/key.txt WRITE: FAILED") app.logger.critical("/data/key.txt WRITE: FAILED")
checks_passed = False checks_passed = False
else: app.logger.error("/data/key.txt EXIST: FAILED - NO ERROR") else:
app.logger.error("/data/key.txt EXIST: FAILED - NO ERROR")
# Check: /etc/headscale/config.yaml is readable: # Check: /etc/headscale/config.yaml is readable:
if os.access('/etc/headscale/config.yaml', os.R_OK): config_readable = True if os.access("/etc/headscale/config.yaml", os.R_OK):
elif os.access('/etc/headscale/config.yml', os.R_OK): config_readable = True config_readable = True
elif os.access("/etc/headscale/config.yml", os.R_OK):
config_readable = True
else: else:
app.logger.error("/etc/headscale/config.y(a)ml: READ: FAILED") app.logger.error("/etc/headscale/config.y(a)ml: READ: FAILED")
checks_passed = False checks_passed = False
@@ -218,11 +283,17 @@ def access_checks():
# Generate the message: # Generate the message:
if not server_reachable: if not server_reachable:
app.logger.critical("Server is unreachable") app.logger.critical("Server is unreachable")
message = """ message = (
"""
<p>Your headscale server is either unreachable or not properly configured. <p>Your headscale server is either unreachable or not properly configured.
Please ensure your configuration is correct (Check for 200 status on Please ensure your configuration is correct (Check for 200 status on
"""+url+"""/api/v1 failed. Response: """+str(response.status_code)+""".)</p>
""" """
+ url
+ """/api/v1 failed. Response: """
+ str(response.status_code)
+ """.)</p>
"""
)
message_html += format_message("Error", "Headscale unreachable", message) message_html += format_message("Error", "Headscale unreachable", message)
@@ -234,7 +305,9 @@ def access_checks():
is named "config.yaml" or "config.yml"</p> is named "config.yaml" or "config.yml"</p>
""" """
message_html += format_message("Error", "/etc/headscale/config.yaml not readable", message) message_html += format_message(
"Error", "/etc/headscale/config.yaml not readable", message
)
if not data_writable: if not data_writable:
app.logger.critical("/data folder is not writable") app.logger.critical("/data folder is not writable")
@@ -266,7 +339,6 @@ def access_checks():
message_html += format_message("Error", "/data not executable", message) message_html += format_message("Error", "/data not executable", message)
if file_exists: if file_exists:
# If it doesn't exist, we assume the user hasn't created it yet. # If it doesn't exist, we assume the user hasn't created it yet.
# Just redirect to the settings page to enter an API Key # Just redirect to the settings page to enter an API Key
@@ -278,7 +350,9 @@ def access_checks():
by UID/GID 1000:1000.</p> by UID/GID 1000:1000.</p>
""" """
message_html += format_message("Error", "/data/key.txt not writable", message) message_html += format_message(
"Error", "/data/key.txt not writable", message
)
if not file_readable: if not file_readable:
app.logger.critical("/data/key.txt is not readable") app.logger.critical("/data/key.txt is not readable")
@@ -288,14 +362,19 @@ def access_checks():
by UID/GID 1000:1000.</p> by UID/GID 1000:1000.</p>
""" """
message_html += format_message("Error", "/data/key.txt not readable", message) message_html += format_message(
"Error", "/data/key.txt not readable", message
)
return message_html return message_html
def load_checks(): def load_checks():
"""Bundles all the checks into a single function to call easier""" """Bundles all the checks into a single function to call easier"""
# General error checks. See the function for more info: # General error checks. See the function for more info:
if access_checks() != "Pass": return 'error_page' if access_checks() != "Pass":
return "error_page"
# If the API key fails, redirect to the settings page: # If the API key fails, redirect to the settings page:
if not key_check(): return 'settings_page' if not key_check():
return "settings_page"
return "Pass" return "Pass"

File diff suppressed because it is too large Load Diff

296
server.py
View File

@@ -1,35 +1,53 @@
# pylint: disable=wrong-import-order import json
import logging
import headscale, helper, json, os, pytz, renderer, secrets, requests, logging import os
from functools import wraps import secrets
from datetime import datetime from datetime import datetime
from flask import Flask, escape, Markup, redirect, render_template, request, url_for from functools import wraps
import pytz
import requests
from dateutil import parser from dateutil import parser
from flask import Flask, Markup, escape, redirect, render_template, request, url_for
from flask_executor import Executor from flask_executor import Executor
from werkzeug.middleware.proxy_fix import ProxyFix from werkzeug.middleware.proxy_fix import ProxyFix
import headscale
import helper
import renderer
# Global vars # Global vars
# Colors: https://materializecss.com/color.html # Colors: https://materializecss.com/color.html
COLOR = os.environ["COLOR"].replace('"', '').lower() COLOR = os.environ["COLOR"].replace('"', "").lower()
COLOR_NAV = COLOR + " darken-1" COLOR_NAV = COLOR + " darken-1"
COLOR_BTN = COLOR + " darken-3" COLOR_BTN = COLOR + " darken-3"
AUTH_TYPE = os.environ["AUTH_TYPE"].replace('"', '').lower() AUTH_TYPE = os.environ["AUTH_TYPE"].replace('"', "").lower()
LOG_LEVEL = os.environ["LOG_LEVEL"].replace('"', '').upper() LOG_LEVEL = os.environ["LOG_LEVEL"].replace('"', "").upper()
# If LOG_LEVEL is DEBUG, enable Flask debugging: # If LOG_LEVEL is DEBUG, enable Flask debugging:
DEBUG_STATE = True if LOG_LEVEL == "DEBUG" else False DEBUG_STATE = True if LOG_LEVEL == "DEBUG" else False
# Initiate the Flask application and logging: # Initiate the Flask application and logging:
app = Flask(__name__, static_url_path="/static") app = Flask(__name__, static_url_path="/static")
match LOG_LEVEL: match LOG_LEVEL:
case "DEBUG" : app.logger.setLevel(logging.DEBUG) case "DEBUG":
case "INFO" : app.logger.setLevel(logging.INFO) app.logger.setLevel(logging.DEBUG)
case "WARNING" : app.logger.setLevel(logging.WARNING) case "INFO":
case "ERROR" : app.logger.setLevel(logging.ERROR) app.logger.setLevel(logging.INFO)
case "CRITICAL": app.logger.setLevel(logging.CRITICAL) case "WARNING":
app.logger.setLevel(logging.WARNING)
case "ERROR":
app.logger.setLevel(logging.ERROR)
case "CRITICAL":
app.logger.setLevel(logging.CRITICAL)
executor = Executor(app) executor = Executor(app)
app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_prefix=1) app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_prefix=1)
app.logger.info("Headscale-WebUI Version: "+os.environ["APP_VERSION"]+" / "+os.environ["GIT_BRANCH"]) app.logger.info(
"Headscale-WebUI Version: "
+ os.environ["APP_VERSION"]
+ " / "
+ os.environ["GIT_BRANCH"]
)
app.logger.info("LOG LEVEL SET TO %s", str(LOG_LEVEL)) app.logger.info("LOG LEVEL SET TO %s", str(LOG_LEVEL))
app.logger.info("DEBUG STATE: %s", str(DEBUG_STATE)) app.logger.info("DEBUG STATE: %s", str(DEBUG_STATE))
@@ -76,19 +94,22 @@ if AUTH_TYPE == "oidc":
app.logger.debug("/app/instances/secrets.json:") app.logger.debug("/app/instances/secrets.json:")
app.logger.debug(secrets_json.read()) app.logger.debug(secrets_json.read())
app.config.update({ app.config.update(
'SECRET_KEY': secrets.token_urlsafe(32), {
'TESTING': DEBUG_STATE, "SECRET_KEY": secrets.token_urlsafe(32),
'DEBUG': DEBUG_STATE, "TESTING": DEBUG_STATE,
'OIDC_CLIENT_SECRETS': '/app/instance/secrets.json', "DEBUG": DEBUG_STATE,
'OIDC_ID_TOKEN_COOKIE_SECURE': True, "OIDC_CLIENT_SECRETS": "/app/instance/secrets.json",
'OIDC_REQUIRE_VERIFIED_EMAIL': False, "OIDC_ID_TOKEN_COOKIE_SECURE": True,
'OIDC_USER_INFO_ENABLED': True, "OIDC_REQUIRE_VERIFIED_EMAIL": False,
'OIDC_OPENID_REALM': 'Headscale-WebUI', "OIDC_USER_INFO_ENABLED": True,
'OIDC_SCOPES': ['openid', 'profile', 'email'], "OIDC_OPENID_REALM": "Headscale-WebUI",
'OIDC_INTROSPECTION_AUTH_METHOD': 'client_secret_post' "OIDC_SCOPES": ["openid", "profile", "email"],
}) "OIDC_INTROSPECTION_AUTH_METHOD": "client_secret_post",
}
)
from flask_oidc import OpenIDConnect from flask_oidc import OpenIDConnect
oidc = OpenIDConnect(app) oidc = OpenIDConnect(app)
elif AUTH_TYPE == "basic": elif AUTH_TYPE == "basic":
@@ -96,23 +117,26 @@ elif AUTH_TYPE == "basic":
app.logger.info("Loading basic auth libraries and configuring app...") app.logger.info("Loading basic auth libraries and configuring app...")
from flask_basicauth import BasicAuth from flask_basicauth import BasicAuth
app.config['BASIC_AUTH_USERNAME'] = os.environ["BASIC_AUTH_USER"].replace('"', '') app.config["BASIC_AUTH_USERNAME"] = os.environ["BASIC_AUTH_USER"].replace('"', "")
app.config['BASIC_AUTH_PASSWORD'] = os.environ["BASIC_AUTH_PASS"] app.config["BASIC_AUTH_PASSWORD"] = os.environ["BASIC_AUTH_PASS"]
app.config['BASIC_AUTH_FORCE'] = True app.config["BASIC_AUTH_FORCE"] = True
basic_auth = BasicAuth(app) basic_auth = BasicAuth(app)
######################################################################################## ########################################################################################
# Set Authentication type - Dynamically load function decorators # Set Authentication type - Dynamically load function decorators
# https://stackoverflow.com/questions/17256602/assertionerror-view-function-mapping-is-overwriting-an-existing-endpoint-functi # https://stackoverflow.com/questions/17256602/assertionerror-view-function-mapping-is-overwriting-an-existing-endpoint-functi
######################################################################################## ########################################################################################
# Make a fake decorator for oidc.require_login # Make a fake decorator for oidc.require_login
# If anyone knows a better way of doing this, please let me know. # If anyone knows a better way of doing this, please let me know.
class OpenIDConnect(): class OpenIDConnect:
def require_login(self, view_func): def require_login(self, view_func):
@wraps(view_func) @wraps(view_func)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
return view_func(*args, **kwargs) return view_func(*args, **kwargs)
return decorated return decorated
oidc = OpenIDConnect() oidc = OpenIDConnect()
else: else:
@@ -122,24 +146,28 @@ else:
######################################################################################## ########################################################################################
# Make a fake decorator for oidc.require_login # Make a fake decorator for oidc.require_login
# If anyone knows a better way of doing this, please let me know. # If anyone knows a better way of doing this, please let me know.
class OpenIDConnect(): class OpenIDConnect:
def require_login(self, view_func): def require_login(self, view_func):
@wraps(view_func) @wraps(view_func)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
return view_func(*args, **kwargs) return view_func(*args, **kwargs)
return decorated return decorated
oidc = OpenIDConnect() oidc = OpenIDConnect()
######################################################################################## ########################################################################################
# / pages - User-facing pages # / pages - User-facing pages
######################################################################################## ########################################################################################
@app.route('/') @app.route("/")
@app.route('/overview') @app.route("/overview")
@oidc.require_login @oidc.require_login
def overview_page(): def overview_page():
# Some basic sanity checks: # Some basic sanity checks:
pass_checks = str(helper.load_checks()) pass_checks = str(helper.load_checks())
if pass_checks != "Pass": return redirect(url_for(pass_checks)) if pass_checks != "Pass":
return redirect(url_for(pass_checks))
# Check if OIDC is enabled. If it is, display the buttons: # Check if OIDC is enabled. If it is, display the buttons:
OIDC_NAV_DROPDOWN = Markup("") OIDC_NAV_DROPDOWN = Markup("")
@@ -151,20 +179,23 @@ def overview_page():
OIDC_NAV_DROPDOWN = renderer.oidc_nav_dropdown(user_name, email_address, name) OIDC_NAV_DROPDOWN = renderer.oidc_nav_dropdown(user_name, email_address, name)
OIDC_NAV_MOBILE = renderer.oidc_nav_mobile(user_name, email_address, name) OIDC_NAV_MOBILE = renderer.oidc_nav_mobile(user_name, email_address, name)
return render_template('overview.html', return render_template(
"overview.html",
render_page=renderer.render_overview(), render_page=renderer.render_overview(),
COLOR_NAV=COLOR_NAV, COLOR_NAV=COLOR_NAV,
COLOR_BTN=COLOR_BTN, COLOR_BTN=COLOR_BTN,
OIDC_NAV_DROPDOWN=OIDC_NAV_DROPDOWN, OIDC_NAV_DROPDOWN=OIDC_NAV_DROPDOWN,
OIDC_NAV_MOBILE = OIDC_NAV_MOBILE OIDC_NAV_MOBILE=OIDC_NAV_MOBILE,
) )
@app.route('/routes', methods=('GET', 'POST'))
@app.route("/routes", methods=("GET", "POST"))
@oidc.require_login @oidc.require_login
def routes_page(): def routes_page():
# Some basic sanity checks: # Some basic sanity checks:
pass_checks = str(helper.load_checks()) pass_checks = str(helper.load_checks())
if pass_checks != "Pass": return redirect(url_for(pass_checks)) if pass_checks != "Pass":
return redirect(url_for(pass_checks))
# Check if OIDC is enabled. If it is, display the buttons: # Check if OIDC is enabled. If it is, display the buttons:
OIDC_NAV_DROPDOWN = Markup("") OIDC_NAV_DROPDOWN = Markup("")
@@ -177,21 +208,23 @@ def routes_page():
OIDC_NAV_DROPDOWN = renderer.oidc_nav_dropdown(user_name, email_address, name) OIDC_NAV_DROPDOWN = renderer.oidc_nav_dropdown(user_name, email_address, name)
OIDC_NAV_MOBILE = renderer.oidc_nav_mobile(user_name, email_address, name) OIDC_NAV_MOBILE = renderer.oidc_nav_mobile(user_name, email_address, name)
return render_template('routes.html', return render_template(
"routes.html",
render_page=renderer.render_routes(), render_page=renderer.render_routes(),
COLOR_NAV=COLOR_NAV, COLOR_NAV=COLOR_NAV,
COLOR_BTN=COLOR_BTN, COLOR_BTN=COLOR_BTN,
OIDC_NAV_DROPDOWN=OIDC_NAV_DROPDOWN, OIDC_NAV_DROPDOWN=OIDC_NAV_DROPDOWN,
OIDC_NAV_MOBILE = OIDC_NAV_MOBILE OIDC_NAV_MOBILE=OIDC_NAV_MOBILE,
) )
@app.route('/machines', methods=('GET', 'POST')) @app.route("/machines", methods=("GET", "POST"))
@oidc.require_login @oidc.require_login
def machines_page(): def machines_page():
# Some basic sanity checks: # Some basic sanity checks:
pass_checks = str(helper.load_checks()) pass_checks = str(helper.load_checks())
if pass_checks != "Pass": return redirect(url_for(pass_checks)) if pass_checks != "Pass":
return redirect(url_for(pass_checks))
# Check if OIDC is enabled. If it is, display the buttons: # Check if OIDC is enabled. If it is, display the buttons:
OIDC_NAV_DROPDOWN = Markup("") OIDC_NAV_DROPDOWN = Markup("")
@@ -205,22 +238,25 @@ def machines_page():
OIDC_NAV_MOBILE = renderer.oidc_nav_mobile(user_name, email_address, name) OIDC_NAV_MOBILE = renderer.oidc_nav_mobile(user_name, email_address, name)
cards = renderer.render_machines_cards() cards = renderer.render_machines_cards()
return render_template('machines.html', return render_template(
"machines.html",
cards=cards, cards=cards,
headscale_server=headscale.get_url(True), headscale_server=headscale.get_url(True),
COLOR_NAV=COLOR_NAV, COLOR_NAV=COLOR_NAV,
COLOR_BTN=COLOR_BTN, COLOR_BTN=COLOR_BTN,
OIDC_NAV_DROPDOWN=OIDC_NAV_DROPDOWN, OIDC_NAV_DROPDOWN=OIDC_NAV_DROPDOWN,
OIDC_NAV_MOBILE=OIDC_NAV_MOBILE, OIDC_NAV_MOBILE=OIDC_NAV_MOBILE,
INPAGE_SEARCH = INPAGE_SEARCH INPAGE_SEARCH=INPAGE_SEARCH,
) )
@app.route('/users', methods=('GET', 'POST'))
@app.route("/users", methods=("GET", "POST"))
@oidc.require_login @oidc.require_login
def users_page(): def users_page():
# Some basic sanity checks: # Some basic sanity checks:
pass_checks = str(helper.load_checks()) pass_checks = str(helper.load_checks())
if pass_checks != "Pass": return redirect(url_for(pass_checks)) if pass_checks != "Pass":
return redirect(url_for(pass_checks))
# Check if OIDC is enabled. If it is, display the buttons: # Check if OIDC is enabled. If it is, display the buttons:
OIDC_NAV_DROPDOWN = Markup("") OIDC_NAV_DROPDOWN = Markup("")
@@ -234,16 +270,18 @@ def users_page():
OIDC_NAV_MOBILE = renderer.oidc_nav_mobile(user_name, email_address, name) OIDC_NAV_MOBILE = renderer.oidc_nav_mobile(user_name, email_address, name)
cards = renderer.render_users_cards() cards = renderer.render_users_cards()
return render_template('users.html', return render_template(
"users.html",
cards=cards, cards=cards,
COLOR_NAV=COLOR_NAV, COLOR_NAV=COLOR_NAV,
COLOR_BTN=COLOR_BTN, COLOR_BTN=COLOR_BTN,
OIDC_NAV_DROPDOWN=OIDC_NAV_DROPDOWN, OIDC_NAV_DROPDOWN=OIDC_NAV_DROPDOWN,
OIDC_NAV_MOBILE=OIDC_NAV_MOBILE, OIDC_NAV_MOBILE=OIDC_NAV_MOBILE,
INPAGE_SEARCH = INPAGE_SEARCH INPAGE_SEARCH=INPAGE_SEARCH,
) )
@app.route('/settings', methods=('GET', 'POST'))
@app.route("/settings", methods=("GET", "POST"))
@oidc.require_login @oidc.require_login
def settings_page(): def settings_page():
# Some basic sanity checks: # Some basic sanity checks:
@@ -261,9 +299,16 @@ def settings_page():
OIDC_NAV_DROPDOWN = renderer.oidc_nav_dropdown(user_name, email_address, name) OIDC_NAV_DROPDOWN = renderer.oidc_nav_dropdown(user_name, email_address, name)
OIDC_NAV_MOBILE = renderer.oidc_nav_mobile(user_name, email_address, name) OIDC_NAV_MOBILE = renderer.oidc_nav_mobile(user_name, email_address, name)
GIT_COMMIT_LINK = Markup("<a href='https://github.com/iFargle/headscale-webui/commit/"+os.environ["GIT_COMMIT"]+"'>"+str(os.environ["GIT_COMMIT"])[0:7]+"</a>") GIT_COMMIT_LINK = Markup(
"<a href='https://github.com/iFargle/headscale-webui/commit/"
+ os.environ["GIT_COMMIT"]
+ "'>"
+ str(os.environ["GIT_COMMIT"])[0:7]
+ "</a>"
)
return render_template('settings.html', return render_template(
"settings.html",
url=headscale.get_url(), url=headscale.get_url(),
COLOR_NAV=COLOR_NAV, COLOR_NAV=COLOR_NAV,
COLOR_BTN=COLOR_BTN, COLOR_BTN=COLOR_BTN,
@@ -273,24 +318,26 @@ def settings_page():
APP_VERSION=os.environ["APP_VERSION"], APP_VERSION=os.environ["APP_VERSION"],
GIT_COMMIT=GIT_COMMIT_LINK, GIT_COMMIT=GIT_COMMIT_LINK,
GIT_BRANCH=os.environ["GIT_BRANCH"], GIT_BRANCH=os.environ["GIT_BRANCH"],
HS_VERSION = os.environ["HS_VERSION"] HS_VERSION=os.environ["HS_VERSION"],
) )
@app.route('/error')
@app.route("/error")
@oidc.require_login @oidc.require_login
def error_page(): def error_page():
if helper.access_checks() == "Pass": if helper.access_checks() == "Pass":
return redirect(url_for('overview_page')) return redirect(url_for("overview_page"))
return render_template('error.html', return render_template("error.html", ERROR_MESSAGE=Markup(helper.access_checks()))
ERROR_MESSAGE = Markup(helper.access_checks())
)
@app.route('/logout')
@app.route("/logout")
def logout_page(): def logout_page():
if AUTH_TYPE == "oidc": if AUTH_TYPE == "oidc":
oidc.logout() oidc.logout()
return redirect(url_for('overview_page')) return redirect(url_for("overview_page"))
######################################################################################## ########################################################################################
# /api pages # /api pages
######################################################################################## ########################################################################################
@@ -299,7 +346,8 @@ def logout_page():
# Headscale API Key Endpoints # Headscale API Key Endpoints
######################################################################################## ########################################################################################
@app.route('/api/test_key', methods=('GET', 'POST'))
@app.route("/api/test_key", methods=("GET", "POST"))
@oidc.require_login @oidc.require_login
def test_key_page(): def test_key_page():
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
@@ -307,14 +355,16 @@ def test_key_page():
# Test the API key. If the test fails, return a failure. # Test the API key. If the test fails, return a failure.
status = headscale.test_api_key(url, api_key) status = headscale.test_api_key(url, api_key)
if status != 200: return "Unauthenticated" if status != 200:
return "Unauthenticated"
renewed = headscale.renew_api_key(url, api_key) renewed = headscale.renew_api_key(url, api_key)
app.logger.warning("The below statement will be TRUE if the key has been renewed, ") app.logger.warning("The below statement will be TRUE if the key has been renewed, ")
app.logger.warning("or DOES NOT need renewal. False in all other cases") app.logger.warning("or DOES NOT need renewal. False in all other cases")
app.logger.warning("Renewed: " + str(renewed)) app.logger.warning("Renewed: " + str(renewed))
# The key works, let's renew it if it needs it. If it does, re-read the api_key from the file: # The key works, let's renew it if it needs it. If it does, re-read the api_key from the file:
if renewed: api_key = headscale.get_api_key() if renewed:
api_key = headscale.get_api_key()
key_info = headscale.get_api_key_info(url, api_key) key_info = headscale.get_api_key_info(url, api_key)
@@ -323,32 +373,47 @@ def test_key_page():
local_time = timezone.localize(datetime.now()) local_time = timezone.localize(datetime.now())
# Format the dates for easy readability # Format the dates for easy readability
creation_parse = parser.parse(key_info['createdAt']) creation_parse = parser.parse(key_info["createdAt"])
creation_local = creation_parse.astimezone(timezone) creation_local = creation_parse.astimezone(timezone)
creation_delta = local_time - creation_local creation_delta = local_time - creation_local
creation_print = helper.pretty_print_duration(creation_delta) creation_print = helper.pretty_print_duration(creation_delta)
creation_time = str(creation_local.strftime('%A %m/%d/%Y, %H:%M:%S'))+" "+str(timezone)+" ("+str(creation_print)+")" creation_time = (
str(creation_local.strftime("%A %m/%d/%Y, %H:%M:%S"))
+ " "
+ str(timezone)
+ " ("
+ str(creation_print)
+ ")"
)
expiration_parse = parser.parse(key_info['expiration']) expiration_parse = parser.parse(key_info["expiration"])
expiration_local = expiration_parse.astimezone(timezone) expiration_local = expiration_parse.astimezone(timezone)
expiration_delta = expiration_local - local_time expiration_delta = expiration_local - local_time
expiration_print = helper.pretty_print_duration(expiration_delta, "expiry") expiration_print = helper.pretty_print_duration(expiration_delta, "expiry")
expiration_time = str(expiration_local.strftime('%A %m/%d/%Y, %H:%M:%S'))+" "+str(timezone)+" ("+str(expiration_print)+")" expiration_time = (
str(expiration_local.strftime("%A %m/%d/%Y, %H:%M:%S"))
+ " "
+ str(timezone)
+ " ("
+ str(expiration_print)
+ ")"
)
key_info['expiration'] = expiration_time key_info["expiration"] = expiration_time
key_info['createdAt'] = creation_time key_info["createdAt"] = creation_time
message = json.dumps(key_info) message = json.dumps(key_info)
return message return message
@app.route('/api/save_key', methods=['POST'])
@app.route("/api/save_key", methods=["POST"])
@oidc.require_login @oidc.require_login
def save_key_page(): def save_key_page():
json_response = request.get_json() json_response = request.get_json()
api_key = json_response['api_key'] api_key = json_response["api_key"]
url = headscale.get_url() url = headscale.get_url()
file_written = headscale.set_api_key(api_key) file_written = headscale.set_api_key(api_key)
message = '' message = ""
if file_written: if file_written:
# Re-read the file and get the new API key and test it # Re-read the file and get the new API key and test it
@@ -356,127 +421,140 @@ def save_key_page():
test_status = headscale.test_api_key(url, api_key) test_status = headscale.test_api_key(url, api_key)
if test_status == 200: if test_status == 200:
key_info = headscale.get_api_key_info(url, api_key) key_info = headscale.get_api_key_info(url, api_key)
expiration = key_info['expiration'] expiration = key_info["expiration"]
message = "Key: '" + api_key + "', Expiration: " + expiration message = "Key: '" + api_key + "', Expiration: " + expiration
# If the key was saved successfully, test it: # If the key was saved successfully, test it:
return "Key saved and tested: " + message return "Key saved and tested: " + message
else: return "Key failed testing. Check your key" else:
else: return "Key did not save properly. Check logs" return "Key failed testing. Check your key"
else:
return "Key did not save properly. Check logs"
######################################################################################## ########################################################################################
# Machine API Endpoints # Machine API Endpoints
######################################################################################## ########################################################################################
@app.route('/api/update_route', methods=['POST']) @app.route("/api/update_route", methods=["POST"])
@oidc.require_login @oidc.require_login
def update_route_page(): def update_route_page():
json_response = request.get_json() json_response = request.get_json()
route_id = escape(json_response['route_id']) route_id = escape(json_response["route_id"])
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
current_state = json_response['current_state'] current_state = json_response["current_state"]
return headscale.update_route(url, api_key, route_id, current_state) return headscale.update_route(url, api_key, route_id, current_state)
@app.route('/api/machine_information', methods=['POST'])
@app.route("/api/machine_information", methods=["POST"])
@oidc.require_login @oidc.require_login
def machine_information_page(): def machine_information_page():
json_response = request.get_json() json_response = request.get_json()
machine_id = escape(json_response['id']) machine_id = escape(json_response["id"])
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
return headscale.get_machine_info(url, api_key, machine_id) return headscale.get_machine_info(url, api_key, machine_id)
@app.route('/api/delete_machine', methods=['POST'])
@app.route("/api/delete_machine", methods=["POST"])
@oidc.require_login @oidc.require_login
def delete_machine_page(): def delete_machine_page():
json_response = request.get_json() json_response = request.get_json()
machine_id = escape(json_response['id']) machine_id = escape(json_response["id"])
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
return headscale.delete_machine(url, api_key, machine_id) return headscale.delete_machine(url, api_key, machine_id)
@app.route('/api/rename_machine', methods=['POST'])
@app.route("/api/rename_machine", methods=["POST"])
@oidc.require_login @oidc.require_login
def rename_machine_page(): def rename_machine_page():
json_response = request.get_json() json_response = request.get_json()
machine_id = escape(json_response['id']) machine_id = escape(json_response["id"])
new_name = escape(json_response['new_name']) new_name = escape(json_response["new_name"])
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
return headscale.rename_machine(url, api_key, machine_id, new_name) return headscale.rename_machine(url, api_key, machine_id, new_name)
@app.route('/api/move_user', methods=['POST'])
@app.route("/api/move_user", methods=["POST"])
@oidc.require_login @oidc.require_login
def move_user_page(): def move_user_page():
json_response = request.get_json() json_response = request.get_json()
machine_id = escape(json_response['id']) machine_id = escape(json_response["id"])
new_user = escape(json_response['new_user']) new_user = escape(json_response["new_user"])
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
return headscale.move_user(url, api_key, machine_id, new_user) return headscale.move_user(url, api_key, machine_id, new_user)
@app.route('/api/set_machine_tags', methods=['POST'])
@app.route("/api/set_machine_tags", methods=["POST"])
@oidc.require_login @oidc.require_login
def set_machine_tags(): def set_machine_tags():
json_response = request.get_json() json_response = request.get_json()
machine_id = escape(json_response['id']) machine_id = escape(json_response["id"])
machine_tags = json_response['tags_list'] machine_tags = json_response["tags_list"]
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
return headscale.set_machine_tags(url, api_key, machine_id, machine_tags) return headscale.set_machine_tags(url, api_key, machine_id, machine_tags)
@app.route('/api/register_machine', methods=['POST'])
@app.route("/api/register_machine", methods=["POST"])
@oidc.require_login @oidc.require_login
def register_machine(): def register_machine():
json_response = request.get_json() json_response = request.get_json()
machine_key = escape(json_response['key']) machine_key = escape(json_response["key"])
user = escape(json_response['user']) user = escape(json_response["user"])
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
return headscale.register_machine(url, api_key, machine_key, user) return headscale.register_machine(url, api_key, machine_key, user)
######################################################################################## ########################################################################################
# User API Endpoints # User API Endpoints
######################################################################################## ########################################################################################
@app.route('/api/rename_user', methods=['POST']) @app.route("/api/rename_user", methods=["POST"])
@oidc.require_login @oidc.require_login
def rename_user_page(): def rename_user_page():
json_response = request.get_json() json_response = request.get_json()
old_name = escape(json_response['old_name']) old_name = escape(json_response["old_name"])
new_name = escape(json_response['new_name']) new_name = escape(json_response["new_name"])
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
return headscale.rename_user(url, api_key, old_name, new_name) return headscale.rename_user(url, api_key, old_name, new_name)
@app.route('/api/add_user', methods=['POST'])
@app.route("/api/add_user", methods=["POST"])
@oidc.require_login @oidc.require_login
def add_user(): def add_user():
json_response = request.get_json() json_response = request.get_json()
user_name = str(escape(json_response['name'])) user_name = str(escape(json_response["name"]))
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
json_string = '{"name": "' + user_name + '"}' json_string = '{"name": "' + user_name + '"}'
return headscale.add_user(url, api_key, json_string) return headscale.add_user(url, api_key, json_string)
@app.route('/api/delete_user', methods=['POST'])
@app.route("/api/delete_user", methods=["POST"])
@oidc.require_login @oidc.require_login
def delete_user(): def delete_user():
json_response = request.get_json() json_response = request.get_json()
user_name = str(escape(json_response['name'])) user_name = str(escape(json_response["name"]))
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
return headscale.delete_user(url, api_key, user_name) return headscale.delete_user(url, api_key, user_name)
@app.route('/api/get_users', methods=['POST'])
@app.route("/api/get_users", methods=["POST"])
@oidc.require_login @oidc.require_login
def get_users_page(): def get_users_page():
url = headscale.get_url() url = headscale.get_url()
@@ -484,10 +562,11 @@ def get_users_page():
return headscale.get_users(url, api_key) return headscale.get_users(url, api_key)
######################################################################################## ########################################################################################
# Pre-Auth Key API Endpoints # Pre-Auth Key API Endpoints
######################################################################################## ########################################################################################
@app.route('/api/add_preauth_key', methods=['POST']) @app.route("/api/add_preauth_key", methods=["POST"])
@oidc.require_login @oidc.require_login
def add_preauth_key(): def add_preauth_key():
json_response = json.dumps(request.get_json()) json_response = json.dumps(request.get_json())
@@ -496,7 +575,8 @@ def add_preauth_key():
return headscale.add_preauth_key(url, api_key, json_response) return headscale.add_preauth_key(url, api_key, json_response)
@app.route('/api/expire_preauth_key', methods=['POST'])
@app.route("/api/expire_preauth_key", methods=["POST"])
@oidc.require_login @oidc.require_login
def expire_preauth_key(): def expire_preauth_key():
json_response = json.dumps(request.get_json()) json_response = json.dumps(request.get_json())
@@ -505,18 +585,20 @@ def expire_preauth_key():
return headscale.expire_preauth_key(url, api_key, json_response) return headscale.expire_preauth_key(url, api_key, json_response)
@app.route('/api/build_preauthkey_table', methods=['POST'])
@app.route("/api/build_preauthkey_table", methods=["POST"])
@oidc.require_login @oidc.require_login
def build_preauth_key_table(): def build_preauth_key_table():
json_response = request.get_json() json_response = request.get_json()
user_name = str(escape(json_response['name'])) user_name = str(escape(json_response["name"]))
return renderer.build_preauth_key_table(user_name) return renderer.build_preauth_key_table(user_name)
######################################################################################## ########################################################################################
# Route API Endpoints # Route API Endpoints
######################################################################################## ########################################################################################
@app.route('/api/get_routes', methods=['POST']) @app.route("/api/get_routes", methods=["POST"])
@oidc.require_login @oidc.require_login
def get_route_info(): def get_route_info():
url = headscale.get_url() url = headscale.get_url()
@@ -528,5 +610,5 @@ def get_route_info():
######################################################################################## ########################################################################################
# Main thread # Main thread
######################################################################################## ########################################################################################
if __name__ == '__main__': if __name__ == "__main__":
app.run(host="0.0.0.0", debug=DEBUG_STATE) app.run(host="0.0.0.0", debug=DEBUG_STATE)