# Source code for nornir_sql.plugins.inventory.sql

"""nornir_sql.plugins.inventory.sql"""
import json
import logging
from typing import Optional, Dict, Type, Any, Union

from nornir.core.inventory import (
    Inventory,
    Groups,
    Host,
    Group,
    ParentGroups,
    Hosts,
    HostOrGroup,
    Defaults,
    ConnectionOptions,
)
from sqlalchemy import create_engine, text
from sqlalchemy.exc import SQLAlchemyError
from pathlib import Path
import ruamel.yaml

logger = logging.getLogger("nornir_sql")


def _get_connection_options(data: Union[str, Dict[str, Any], None]) -> Dict[str, "ConnectionOptions"]:
    """Create ConnectionOptions objects from a connection-options configuration

    Args:
        data (str|dict|None): Connection options as a dict, or as a JSON string
            (e.g. the value of a SQL ``connection_options`` column). ``None``,
            an empty/blank string, or JSON that decodes to an empty value
            (``{}``, ``null``) all mean "no connection options" - this is what
            a NULL or empty SQL column yields.

    Returns:
        dict with one ConnectionOptions object per top-level key of ``data``

    Raises:
        json.JSONDecodeError: if ``data`` is a non-blank string that is not valid JSON
    """
    if isinstance(data, str):
        # convert JSON to dict; a blank string (empty SQL column) is "no options"
        data = json.loads(data) if data.strip() else {}
    if not data:  # None (NULL column / JSON null) or empty mapping
        return {}
    return {
        name: ConnectionOptions(
            hostname=options.get("hostname"),
            port=options.get("port"),
            username=options.get("username"),
            password=options.get("password"),
            platform=options.get("platform"),
            extras=options.get("extras"),
        )
        for name, options in data.items()
    }


def _get_defaults(data: Optional[Dict[str, Any]] = None) -> Defaults:
    """Build the inventory-wide Defaults object

    Args:
        data (dict|None): optional mapping that may contain ``hostname``,
            ``port``, ``username``, ``password``, ``platform``, ``data`` and
            ``connection_options``. Missing keys are passed to Defaults as
            ``None``; a missing ``connection_options`` yields no options.

    Returns:
        Defaults object built from ``data``
    """
    settings = {} if data is None else data
    plain_fields = {
        field: settings.get(field)
        for field in ("hostname", "port", "username", "password", "platform", "data")
    }
    return Defaults(
        connection_options=_get_connection_options(settings.get("connection_options", {})),
        **plain_fields,
    )


class SQLInventory:
    """SQLInventory implements SQL inventory plugin for Nornir"""

    def __init__(
        self,
        sql_connection: str,
        hosts_query: str,
        groups_query: str = "",
        groups_file: Optional[str] = None,
        groups: Optional[dict] = None,
        defaults: Optional[Dict[str, Any]] = None,
    ):
        """Setup SQLInventory parameters

        The SQL queries' fields must stick to the naming convention as follows:

        #. | ``name``: The device name in the inventory
        #. | ``hostname``: Device's hostname/fqdn/ip which is accessible
        #. | ``port``: Port on which the device is accessible
        #. | ``username``: Username on the device
        #. | ``password``: Password on the device
        #. | ``platform``: Platform to use with the connection
        #. | ``data.extra1``: Will be put to ``data`` with the name of ``extra1``
             (everything after the ``data.`` prefix becomes the key)
        #. | ``groups``: Comma separated group names for this host
        #. | ``connection_options``: JSON formatted connection_options string

        Args:
            sql_connection (str): SQL connection string. E.g.: 'mssql+pymssql://@SERVERNAME/DBNAME'
            hosts_query (str): Query string for getting hosts. All fields must be named as above!
            groups_query (str): Query string for getting groups. All fields must be named as above!
            groups_file (str): YAML file path to group definition file.
                Ignored when groups_query or groups are specified!
            groups (dict): group definition as dict. Ignored when groups_query is specified!
            defaults (dict): dict of default values.

        Raises:
            SQLAlchemyError: if ``create_engine`` fails with one (logged, then re-raised).
        """
        self.hosts_query: str = hosts_query
        self.groups_query: str = groups_query
        self.groups_file: Optional[Path] = Path(groups_file).expanduser() if groups_file else None
        self.groups: Optional[dict] = groups
        self.defaults: Defaults = _get_defaults(defaults)
        self.engine = None
        try:
            self.engine = create_engine(sql_connection)
        except SQLAlchemyError as err:
            logger.error(err)
            raise

    def _get_inventory_element(self, typ: Type[HostOrGroup], data: Dict[str, Any]) -> HostOrGroup:
        """Create a Host or Group object from dict

        Args:
            typ: Host or Group type
            data: elements for the object. Either a SQL row converted to a dict
                (``groups`` as a comma separated string, extra data as
                ``data.<key>`` columns) or a group definition from ``groups`` /
                ``groups_file`` (``groups`` as a list, ``data`` as a dict).

        Returns:
            Host or Group object. Its ``groups`` still holds group *names*;
            ``load`` replaces them with a ParentGroups object.
        """
        raw_groups = data.get("groups")
        if isinstance(raw_groups, list):  # groups come from groups_file / groups dict
            groups = raw_groups
        elif raw_groups:  # groups come from sql as a comma separated string
            # drop empty items so "a,b," or "a,,b" does not produce a "" group name
            groups = [name for name in raw_groups.replace(" ", "").split(",") if name]
        else:
            groups = []

        raw_data = data.get("data")
        if isinstance(raw_data, dict):  # extra data come from groups_file / groups dict
            extra_data = raw_data
        else:  # extra data is provided by SQL as "data.<key>" columns
            prefix = "data."
            # startswith (not substring) so e.g. "metadata.x" is not picked up;
            # slicing keeps dotted keys such as "data.a.b" -> "a.b" intact
            extra_data = {key[len(prefix):]: value for key, value in data.items() if key.startswith(prefix)}

        return typ(
            name=data.get("name"),
            hostname=data.get("hostname"),
            port=data.get("port"),
            username=data.get("username"),
            password=data.get("password"),
            platform=data.get("platform"),
            # ParentGroups object will be prepared after groups are loaded. Here we note the group names.
            groups=groups,
            data=extra_data,
            defaults=self.defaults,
            # "or {}": a NULL / empty connection_options column must mean "no options"
            connection_options=_get_connection_options(data.get("connection_options") or {}),
        )

    @staticmethod
    def _fetch_rows(connection, query: str) -> list:
        """Execute ``query`` and return every row as a ``{column: value}`` dict

        Rows are materialized into a list so they remain usable after the
        connection is closed.
        """
        result = connection.execute(text(query))
        columns = list(result.keys())  # same for every row, so read it once
        # Rows are iterated positionally, which works for SQLAlchemy 1.x and 2.0 rows.
        return [dict(zip(columns, row)) for row in result]

    @staticmethod
    def _definition_rows(definitions: Dict[str, Any]) -> list:
        """Turn a ``{group_name: attributes}`` mapping into rows for _get_inventory_element"""
        # attributes is None for an empty YAML entry such as "core:"; a "name"
        # key inside the attributes still overrides the mapping key.
        return [{"name": name, **(attributes or {})} for name, attributes in definitions.items()]

    def _read_groups_file(self) -> Dict[str, Any]:
        """Return the group definitions stored in ``groups_file``

        A missing file is logged and treated like an empty one.
        """
        if not self.groups_file.exists():
            logger.warning("groups_file %s does not exist, no groups loaded from it", self.groups_file)
            return {}
        yml = ruamel.yaml.YAML(typ="safe")
        with open(self.groups_file, encoding="utf-8") as fi:
            return yml.load(fi) or {}

    def load(self) -> Inventory:
        """Load inventory from SQL server

        Hosts come from ``hosts_query``. Groups come from the first configured
        source, in this order: ``groups_query``, ``groups``, ``groups_file``.

        Returns:
            Inventory built from the loaded hosts, groups and defaults

        Raises:
            SQLAlchemyError: on database errors (logged, then re-raised).
            KeyError: when at least one group was loaded and a host or group
                references a group name that was not loaded.
        """
        try:
            with self.engine.connect() as connection:
                host_rows = self._fetch_rows(connection, self.hosts_query)
                group_rows = self._fetch_rows(connection, self.groups_query) if self.groups_query else []
        except SQLAlchemyError as err:
            logger.error("SQL error: %s", err)
            raise

        hosts = Hosts()
        for row in host_rows:
            host = self._get_inventory_element(Host, row)
            hosts[host.name] = host

        if self.groups_query:
            pass  # group_rows already fetched from SQL above
        elif self.groups:
            group_rows = self._definition_rows(self.groups)
        elif self.groups_file:
            group_rows = self._definition_rows(self._read_groups_file())

        groups = Groups()
        for row in group_rows:
            group = self._get_inventory_element(Group, row)
            groups[group.name] = group

        if groups:
            # replace the noted group names with objects (groups first, then hosts).
            # Without any loaded group the names are left untouched, as before.
            for element in (*groups.values(), *hosts.values()):
                element.groups = ParentGroups([groups[name] for name in element.groups])

        return Inventory(hosts=hosts, groups=groups, defaults=self.defaults)