mirror of
https://github.com/ansible-collections/hetzner.hcloud
synced 2024-11-10 06:34:13 +00:00
fix(inventory): improve performance (#402)
##### SUMMARY Improve the performance of the inventory plugin by: - Cache client requests - Move servers `status` filtering to query params.
This commit is contained in:
parent
fb40a00689
commit
f85d8f4492
3 changed files with 134 additions and 87 deletions
7
examples/inventory.hcloud.yml
Normal file
7
examples/inventory.hcloud.yml
Normal file
|
@ -0,0 +1,7 @@
|
|||
# You can list the hosts using:
|
||||
# ansible-inventory --list -i examples/inventory.hcloud.yml --extra-vars=network_name=my-network
|
||||
|
||||
plugin: hetzner.hcloud.hcloud
|
||||
|
||||
network: "{{ network_name }}"
|
||||
status: [running]
|
|
@ -144,8 +144,14 @@ from ansible.module_utils.common.text.converters import to_native
|
|||
from ansible.plugins.inventory import BaseInventoryPlugin, Cacheable, Constructable
|
||||
from ansible.utils.display import Display
|
||||
|
||||
from ..module_utils.client import HAS_DATEUTIL, HAS_REQUESTS
|
||||
from ..module_utils.vendor import hcloud
|
||||
from ..module_utils.client import (
|
||||
Client,
|
||||
ClientException,
|
||||
client_check_required_lib,
|
||||
client_get_by_name_or_id,
|
||||
)
|
||||
from ..module_utils.vendor.hcloud import APIException
|
||||
from ..module_utils.vendor.hcloud.networks import Network
|
||||
from ..module_utils.vendor.hcloud.servers import Server
|
||||
from ..module_utils.version import version
|
||||
|
||||
|
@ -196,13 +202,24 @@ else:
|
|||
InventoryServer = dict
|
||||
|
||||
|
||||
def first_ipv6_address(network: str) -> str:
|
||||
"""
|
||||
Return the first address for a ipv6 network.
|
||||
|
||||
:param network: IPv6 Network.
|
||||
"""
|
||||
return next(IPv6Network(network).hosts())
|
||||
|
||||
|
||||
class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
|
||||
NAME = "hetzner.hcloud.hcloud"
|
||||
|
||||
inventory: InventoryData
|
||||
display: Display
|
||||
|
||||
client: hcloud.Client
|
||||
client: Client
|
||||
|
||||
network: Network | None
|
||||
|
||||
def _configure_hcloud_client(self):
|
||||
# If api_token_env is not the default, print a deprecation warning and load the
|
||||
|
@ -232,7 +249,7 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
|
|||
# Resolve template string
|
||||
api_token = self.templar.template(api_token)
|
||||
|
||||
self.client = hcloud.Client(
|
||||
self.client = Client(
|
||||
token=api_token,
|
||||
api_endpoint=api_endpoint,
|
||||
application_name="ansible-inventory",
|
||||
|
@ -242,61 +259,47 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
|
|||
try:
|
||||
# Ensure the api token is valid
|
||||
self.client.locations.get_list()
|
||||
except hcloud.APIException as exception:
|
||||
except APIException as exception:
|
||||
raise AnsibleError("Invalid Hetzner Cloud API Token.") from exception
|
||||
|
||||
def _get_servers(self):
|
||||
if len(self.get_option("label_selector")) > 0:
|
||||
self.servers = self.client.servers.get_all(label_selector=self.get_option("label_selector"))
|
||||
else:
|
||||
self.servers = self.client.servers.get_all()
|
||||
|
||||
def _filter_servers(self):
|
||||
def _validate_options(self) -> None:
|
||||
if self.get_option("network"):
|
||||
network = self.templar.template(self.get_option("network"), fail_on_undefined=False) or self.get_option(
|
||||
"network"
|
||||
)
|
||||
network_param: str = self.get_option("network")
|
||||
network_param = self.templar.template(network_param)
|
||||
|
||||
try:
|
||||
self.network = self.client.networks.get_by_name(network)
|
||||
if self.network is None:
|
||||
self.network = self.client.networks.get_by_id(network)
|
||||
except hcloud.APIException:
|
||||
raise AnsibleError("The given network is not found.")
|
||||
self.network = client_get_by_name_or_id(self.client, "networks", network_param)
|
||||
except (ClientException, APIException) as exception:
|
||||
raise AnsibleError(to_native(exception)) from exception
|
||||
|
||||
tmp = []
|
||||
for server in self.servers:
|
||||
for server_private_network in server.private_net:
|
||||
if server_private_network.network.id == self.network.id:
|
||||
tmp.append(server)
|
||||
self.servers = tmp
|
||||
def _fetch_servers(self) -> list[Server]:
|
||||
self._validate_options()
|
||||
|
||||
if self.get_option("locations"):
|
||||
tmp = []
|
||||
for server in self.servers:
|
||||
if server.datacenter.location.name in self.get_option("locations"):
|
||||
tmp.append(server)
|
||||
self.servers = tmp
|
||||
|
||||
if self.get_option("types"):
|
||||
tmp = []
|
||||
for server in self.servers:
|
||||
if server.server_type.name in self.get_option("types"):
|
||||
tmp.append(server)
|
||||
self.servers = tmp
|
||||
|
||||
if self.get_option("images"):
|
||||
tmp = []
|
||||
for server in self.servers:
|
||||
if server.image is not None and server.image.os_flavor in self.get_option("images"):
|
||||
tmp.append(server)
|
||||
self.servers = tmp
|
||||
get_servers_params = {}
|
||||
if self.get_option("label_selector"):
|
||||
get_servers_params["label_selector"] = self.get_option("label_selector")
|
||||
|
||||
if self.get_option("status"):
|
||||
tmp = []
|
||||
for server in self.servers:
|
||||
if server.status in self.get_option("status"):
|
||||
tmp.append(server)
|
||||
self.servers = tmp
|
||||
get_servers_params["status"] = self.get_option("status")
|
||||
|
||||
servers = self.client.servers.get_all(**get_servers_params)
|
||||
|
||||
if self.get_option("network"):
|
||||
servers = [s for s in servers if self.network.id in [p.network.id for p in s.private_net]]
|
||||
|
||||
if self.get_option("locations"):
|
||||
locations: list[str] = self.get_option("locations")
|
||||
servers = [s for s in servers if s.datacenter.location.name in locations]
|
||||
|
||||
if self.get_option("types"):
|
||||
server_types: list[str] = self.get_option("types")
|
||||
servers = [s for s in servers if s.server_type.name in server_types]
|
||||
|
||||
if self.get_option("images"):
|
||||
images: list[str] = self.get_option("images")
|
||||
servers = [s for s in servers if s.image is not None and s.image.os_flavor in images]
|
||||
|
||||
return servers
|
||||
|
||||
def _build_inventory_server(self, server: Server) -> InventoryServer:
|
||||
server_dict: InventoryServer = {}
|
||||
|
@ -311,7 +314,7 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
|
|||
server_dict["ipv4"] = to_native(server.public_net.ipv4.ip)
|
||||
|
||||
if server.public_net.ipv6:
|
||||
server_dict["ipv6"] = to_native(self._first_ipv6_address(server.public_net.ipv6.ip))
|
||||
server_dict["ipv6"] = to_native(first_ipv6_address(server.public_net.ipv6.ip))
|
||||
server_dict["ipv6_network"] = to_native(server.public_net.ipv6.network)
|
||||
server_dict["ipv6_network_mask"] = to_native(server.public_net.ipv6.network_mask)
|
||||
|
||||
|
@ -320,10 +323,11 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
|
|||
]
|
||||
|
||||
if self.get_option("network"):
|
||||
for server_private_network in server.private_net:
|
||||
for private_net in server.private_net:
|
||||
# Set private_ipv4 if user filtered for one network
|
||||
if server_private_network.network.id == self.network.id:
|
||||
server_dict["private_ipv4"] = to_native(server_private_network.ip)
|
||||
if private_net.network.id == self.network.id:
|
||||
server_dict["private_ipv4"] = to_native(private_net.ip)
|
||||
break
|
||||
|
||||
# Server Type
|
||||
if server.server_type is not None:
|
||||
|
@ -353,60 +357,54 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
|
|||
|
||||
return server_dict
|
||||
|
||||
def _get_server_ansible_host(self, server):
|
||||
def _get_server_ansible_host(self, server: Server):
|
||||
if self.get_option("connect_with") == "public_ipv4":
|
||||
if server.public_net.ipv4:
|
||||
return to_native(server.public_net.ipv4.ip)
|
||||
else:
|
||||
raise AnsibleError("Server has no public ipv4, but connect_with=public_ipv4 was specified")
|
||||
raise AnsibleError("Server has no public ipv4, but connect_with=public_ipv4 was specified")
|
||||
|
||||
if self.get_option("connect_with") == "public_ipv6":
|
||||
if server.public_net.ipv6:
|
||||
return to_native(self._first_ipv6_address(server.public_net.ipv6.ip))
|
||||
else:
|
||||
raise AnsibleError("Server has no public ipv6, but connect_with=public_ipv6 was specified")
|
||||
return to_native(first_ipv6_address(server.public_net.ipv6.ip))
|
||||
raise AnsibleError("Server has no public ipv6, but connect_with=public_ipv6 was specified")
|
||||
|
||||
elif self.get_option("connect_with") == "hostname":
|
||||
if self.get_option("connect_with") == "hostname":
|
||||
# every server has a name, no need to guard this
|
||||
return to_native(server.name)
|
||||
|
||||
elif self.get_option("connect_with") == "ipv4_dns_ptr":
|
||||
if self.get_option("connect_with") == "ipv4_dns_ptr":
|
||||
if server.public_net.ipv4:
|
||||
return to_native(server.public_net.ipv4.dns_ptr)
|
||||
else:
|
||||
raise AnsibleError("Server has no public ipv4, but connect_with=ipv4_dns_ptr was specified")
|
||||
raise AnsibleError("Server has no public ipv4, but connect_with=ipv4_dns_ptr was specified")
|
||||
|
||||
elif self.get_option("connect_with") == "private_ipv4":
|
||||
if self.get_option("connect_with") == "private_ipv4":
|
||||
if self.get_option("network"):
|
||||
for server_private_network in server.private_net:
|
||||
if server_private_network.network.id == self.network.id:
|
||||
return to_native(server_private_network.ip)
|
||||
for private_net in server.private_net:
|
||||
if private_net.network.id == self.network.id:
|
||||
return to_native(private_net.ip)
|
||||
|
||||
else:
|
||||
raise AnsibleError("You can only connect via private IPv4 if you specify a network")
|
||||
|
||||
def _first_ipv6_address(self, network):
|
||||
return next(IPv6Network(network).hosts())
|
||||
|
||||
def verify_file(self, path):
|
||||
"""Return the possibly of a file being consumable by this plugin."""
|
||||
return super().verify_file(path) and path.endswith(("hcloud.yaml", "hcloud.yml"))
|
||||
|
||||
def _get_cached_result(self, path, cache) -> tuple[list[InventoryServer | None], bool]:
|
||||
def _get_cached_result(self, path, cache) -> tuple[list[InventoryServer], bool]:
|
||||
# false when refresh_cache or --flush-cache is used
|
||||
if not cache:
|
||||
return None, False
|
||||
return [], False
|
||||
|
||||
# get the user-specified directive
|
||||
if not self.get_option("cache"):
|
||||
return None, False
|
||||
return [], False
|
||||
|
||||
cache_key = self.get_cache_key(path)
|
||||
try:
|
||||
cached_result = self._cache[cache_key]
|
||||
except KeyError:
|
||||
# if cache expires or cache file doesn"t exist
|
||||
return None, False
|
||||
return [], False
|
||||
|
||||
return cached_result, True
|
||||
|
||||
|
@ -426,24 +424,27 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
|
|||
def parse(self, inventory, loader, path, cache=True):
|
||||
super().parse(inventory, loader, path, cache)
|
||||
|
||||
if not HAS_REQUESTS:
|
||||
raise AnsibleError("The Hetzner Cloud dynamic inventory plugin requires requests.")
|
||||
if not HAS_DATEUTIL:
|
||||
raise AnsibleError("The Hetzner Cloud dynamic inventory plugin requires python-dateutil.")
|
||||
try:
|
||||
client_check_required_lib()
|
||||
except ClientException as exception:
|
||||
raise AnsibleError(to_native(exception)) from exception
|
||||
|
||||
# Allow using extra variables arguments as template variables (e.g.
|
||||
# '--extra-vars my_var=my_value')
|
||||
self.templar.available_variables = self._vars
|
||||
|
||||
self._read_config_data(path)
|
||||
self._configure_hcloud_client()
|
||||
|
||||
self.servers, cached = self._get_cached_result(path, cache)
|
||||
servers, cached = self._get_cached_result(path, cache)
|
||||
if not cached:
|
||||
self._get_servers()
|
||||
self._filter_servers()
|
||||
self.servers = [self._build_inventory_server(server) for server in self.servers]
|
||||
with self.client.cached_session():
|
||||
servers = [self._build_inventory_server(s) for s in self._fetch_servers()]
|
||||
|
||||
# Add a top group
|
||||
self.inventory.add_group(group=self.get_option("group"))
|
||||
|
||||
for server in self.servers:
|
||||
for server in servers:
|
||||
self.inventory.add_host(server["name"], group=self.get_option("group"))
|
||||
for key, value in server.items():
|
||||
self.inventory.set_variable(server["name"], key, value)
|
||||
|
@ -475,4 +476,4 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
|
|||
strict=strict,
|
||||
)
|
||||
|
||||
self._update_cached_result(path, cache, self.servers)
|
||||
self._update_cached_result(path, cache, servers)
|
||||
|
|
|
@ -2,9 +2,11 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
from ansible.module_utils.basic import missing_required_lib
|
||||
|
||||
from .vendor.hcloud import APIException, Client
|
||||
from .vendor.hcloud import APIException, Client as ClientBase
|
||||
|
||||
HAS_REQUESTS = True
|
||||
HAS_DATEUTIL = True
|
||||
|
@ -61,3 +63,40 @@ def client_get_by_name_or_id(client: Client, resource: str, param: str | int):
|
|||
if exception.code == "not_found":
|
||||
raise _client_resource_not_found(resource, param) from exception
|
||||
raise exception
|
||||
|
||||
|
||||
if HAS_REQUESTS:
|
||||
|
||||
class CachedSession(requests.Session):
|
||||
cache: dict[str, requests.Response] = {}
|
||||
|
||||
def send(self, request: requests.PreparedRequest, **kwargs) -> requests.Response: # type: ignore[no-untyped-def]
|
||||
"""
|
||||
Send a given PreparedRequest.
|
||||
"""
|
||||
if request.method != "GET" or request.url is None:
|
||||
return super().send(request, **kwargs)
|
||||
|
||||
if request.url in self.cache:
|
||||
return self.cache[request.url]
|
||||
|
||||
response = super().send(request, **kwargs)
|
||||
if response.ok:
|
||||
self.cache[request.url] = response
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class Client(ClientBase):
|
||||
@contextmanager
|
||||
def cached_session(self) -> None:
|
||||
"""
|
||||
Swap the client session during the scope of the context. The session will cache
|
||||
all GET requests.
|
||||
|
||||
Cached response will not expire, therefore the cached client must not be used
|
||||
for long living scopes.
|
||||
"""
|
||||
self._requests_session = CachedSession()
|
||||
yield
|
||||
self._requests_session = requests.Session()
|
||||
|
|
Loading…
Reference in a new issue