"""Core objects for interacting with e3 environment definition."""
import copy
import functools
import logging
import pathlib
import pprint
from typing import Dict, List, Optional, Union, Set
import yaml
from yaml.scanner import ScannerError
from e3_build_tools import utils
from e3_build_tools.exceptions import DependencyResolutionError, ModuleExistsException
from e3_build_tools.module import EPICSBaseSource, ModuleSource, RequireSource
logger = logging.getLogger(__name__)

# Shape of the data written to / read from specification files:
# {"metadata": {...}, "config": {...}, "modules": {...}}.
Specification = Dict[str, Dict[str, Union[str, int, Dict[str, Optional[List[str]]]]]]
class BuildDefinition:
    """Data structure to store the definition of an e3 environment.

    The e3 environment definition includes the list of modules in the e3
    environment, and the EPICS base version and require version that the
    environment will be built for. It provides a set of functions to manage
    the module list.
    """

    # Expected metadata values for input specification files; subclasses
    # override these (see SpecificationDefinition).
    _input_file_type = "specification"
    _input_file_version = 1

    def __init__(
        self,
        name: str,
        *,
        base_ref: str,
        require_ref: str,
        modules: Optional[Dict[str, ModuleSource]] = None,
    ) -> None:
        """Initialise the e3 environment definition.

        Args:
            name: Name of the environment; used by __repr__ and as the
                default specification file stem.
            base_ref: Reference of EPICS base to build against.
            require_ref: Reference of require to build against.
            modules: Optional initial mapping of module name to source.
        """
        # Fix: `name` was accepted but never stored, although __repr__ and
        # to_specification() both read `self.name`.
        self.name = name
        self.base_ref = base_ref
        self.require_ref = require_ref
        # The provided module dictionary is deep-copied into this object as
        # it will be modified by add_module/remove_module.
        self.modules = {} if modules is None else copy.deepcopy(modules)
        # EPICS base and require are added to the provided list of modules
        # as they are not included in the input list of modules.
        self.add_module(EPICSBaseSource(version=base_ref))
        self.add_module(RequireSource(version=require_ref))

    @property
    def base_version(self) -> str:
        """Return e3 base version."""
        return self.modules["base"].versions[self.base_ref]["version_string"]

    @property
    def require_version(self) -> str:
        """Return e3 require version."""
        return self.modules["require"].versions[self.require_ref]["version_string"]

    def __str__(self) -> str:
        """Print the list of modules in the environment in YAML format."""
        return (
            "---\n"
            + yaml.dump(
                {
                    module.name: list(module.versions)
                    for module in self.modules.values()
                }
            ).strip()
            + "\n..."
        )

    def __repr__(self) -> str:
        """Return the serialized object (base/require modules elided)."""
        non_base_require_modules = {
            key: val
            for key, val in self.modules.items()
            if key not in ("base", "require")
        }
        return (
            f"{self.__class__.__name__}"
            f"('{self.name}', "
            f"base_ref='{self.base_ref}', "
            f"require_ref='{self.require_ref}', "
            f"modules={non_base_require_modules})"
        )

    def add_module(self, module: ModuleSource) -> None:
        """Add a module to the definition.

        Raises:
            ModuleExistsException: If the module name is already in the
                existing modules.
        """
        if module.name in self.modules:
            raise ModuleExistsException
        self.modules[module.name] = module

    def remove_module(self, module_name: str) -> None:
        """Remove a module from the environment.

        Raises:
            KeyError: If the module is not present.
        """
        logger.debug(f"Removing module {module_name}")
        del self.modules[module_name]

    def to_specification(self, path: Optional[pathlib.Path] = None) -> None:
        """Save the current definition to file.

        Args:
            path: Target file; defaults to '<name>_spec.yml' in the current
                directory.
        """
        if path is None:
            path = pathlib.Path(self.name + "_spec").with_suffix(".yml")
        logger.debug(
            f"Saving EnvironmentDefinition to specification-file '{path.absolute()!s}'"
        )
        specification: Specification = {
            "metadata": {
                # NOTE(review): deliberately hard-coded rather than using
                # cls._input_file_type/_input_file_version — subclasses that
                # override those (e.g. 'formula') still save a
                # 'specification' file here; confirm this is intended.
                "type": "specification",
                "version": 1,
            },
            "config": {
                "base": self.base_ref,
                "require": self.require_ref,
            },
            "modules": {},
        }
        for name, module in self.modules.items():
            # base/require are implicit via the config section.
            if name in ("base", "require"):
                continue
            # Fix: the inner comprehension variable shadowed the loop
            # variable `name`; a plain list() copy is clearer.
            specification["modules"][name] = {
                "versions": list(module.versions)
            }
        with open(path, "w") as f:
            f.write(yaml.safe_dump(specification))

    @classmethod
    def from_specification(cls, path: pathlib.Path) -> "BuildDefinition":
        """Construct a `BuildDefinition` from a specification.

        Where a specification is a yaml-file with a specific structure
        listing the contents of an e3 environment.

        Raises:
            ModuleExistsException: If the module name is already in the
                existing modules.
            OSError: If a failure occurs during file opening.
            TypeError: If input file does not match expected syntax.
        """
        with open(path, "r") as f:
            try:
                data = yaml.safe_load(f)
            except ScannerError as e:
                # Chain the cause so the YAML error location is preserved.
                raise TypeError(f"Failure scanning file {f}.") from e
        cls._validate_data(data)
        logger.debug(f"Creating definition from specification '{path.absolute()!s}'")
        try:
            base_ref = data["config"]["base"]
            require_ref = data["config"]["require"]
            instance = cls(
                name=path.stem,
                base_ref=base_ref,
                require_ref=require_ref,
            )
            for name, info in data["modules"].items():
                instance.add_module(ModuleSource(name, versions=info["versions"]))
        except KeyError as e:
            raise TypeError(f"Missing entry for {str(e)} in file.") from e
        return instance

    @classmethod
    def _validate_data(cls, data):
        """Validate input file data.

        Raises:
            TypeError: If input file does not match expected syntax.
        """
        try:
            if (
                data["metadata"]["type"] != cls._input_file_type
                or data["metadata"]["version"] != cls._input_file_version
            ):
                raise TypeError("Invalid input file: metadata mismatch.")
        except KeyError as e:
            raise TypeError("Invalid input file: metadata key missing.") from e
class SpecificationDefinition(BuildDefinition):
    """Data structure to make changes to wrappers to create new environments."""

    # Expected metadata values for input formula files.
    _input_file_type = "formula"
    _input_file_version = 1

    def __init__(self, *args, formula, **kwargs):
        """Initialize object.

        Args:
            formula: Mapping describing module substitutions and updates;
                keyword-only, forwarded nowhere else.
        """
        super().__init__(*args, **kwargs)
        # NOTE(review): `Formula` is not defined in the visible part of this
        # module — presumably a type alias declared elsewhere; verify it is
        # in scope before __init__ runs, since the annotation is evaluated.
        self._formula: Formula = formula

    def __str__(self) -> str:
        """Print the formula in YAML format."""
        return "---\n" + yaml.dump(self._formula).strip() + "\n..."

    def fetch_starting_reference(self) -> Dict[str, Union[str, Dict[str, str]]]:
        """Return the starting reference for all modules that declare one."""
        module_refs: Dict[str, Union[str, Dict[str, str]]] = {}
        for module, details in self._formula["modules"].items():
            # Entries may be plain strings; only dict entries can carry a
            # 'starting_ref' key.
            if isinstance(details, dict) and "starting_ref" in details:
                module_refs[module] = details["starting_ref"]
        return module_refs

    def update_module_version(self, module_name: str, version: str) -> None:
        """Update the module-level substitution with a new version."""
        self._formula["modules"][module_name] = utils.deep_merge(  # type: ignore
            self._formula["modules"][module_name],  # type: ignore
            {  # type: ignore
                "substitutions": {
                    "configure/CONFIG_MODULE": {"E3_MODULE_VERSION": version}
                }
            },
        )

    def update_global_dependency_from_module(self, module_name: str) -> None:
        """Update the global-level substitution with a new version.

        Reads E3_MODULE_VERSION for `module_name` from the module-level
        substitutions and merges it into the global substitutions as
        <MODULE>_DEP_VERSION. Returns silently when the module has no such
        substitution.
        """
        dep_var = f"{module_name.upper()}_DEP_VERSION"
        try:
            new_version = self._formula["modules"][module_name]["substitutions"][  # type: ignore
                "configure/CONFIG_MODULE"  # type: ignore
            ][
                "E3_MODULE_VERSION"  # type: ignore
            ]
        except KeyError:
            # No module-level version substitution: nothing to propagate.
            return
        logger.debug(
            f"Updating global substitutions with {dep_var!r} = {new_version!r}"
        )
        self._formula = utils.deep_merge(  # type: ignore
            {"substitutions": {"configure/CONFIG_MODULE": {dep_var: new_version}}},  # type: ignore
            self._formula,  # type: ignore
        )

    def combine_substitutions(self, name: str) -> Dict[str, Dict[str, str]]:
        """Return a deep merge of all substitutions in loaded formula."""
        # We can't easily check for `version_subs` right now since the
        # versions are a list. This does however not matter for the current
        # PoC since it only has global and module level substitutions;
        # i.e. `ref` is ignored.
        module_subs = self._formula["modules"][name].get("substitutions", {})  # type: ignore
        global_subs = self._formula.get("substitutions", {})
        substitutions = functools.reduce(utils.deep_merge, (module_subs, global_subs))  # type: ignore
        return substitutions  # type: ignore

    def submodule_updates(self, name: str) -> Dict[str, str]:
        """Return the submodule updates for a given wrapper."""
        return self._formula["modules"][name].get("submodule_updates", {})  # type: ignore

    # NOTE(review): the original source ended with a dangling `@classmethod`
    # decorator with no method attached (apparently truncated during
    # extraction); removed here, as a bare decorator is a syntax error.
    # Recover the missing classmethod from version control if it existed.
class Resolver:
    """Class for resolving environment data for build purposes."""

    def __init__(self, *, ignore_dependencies: bool = False) -> None:
        """Initialize object.

        Args:
            ignore_dependencies: When True, missing dependencies are only
                logged instead of raising DependencyResolutionError.
        """
        self.ignore_dependencies = ignore_dependencies

    def get_sorted_order(self, environment: BuildDefinition) -> List[str]:
        """Return a topological sort of the dependency graph.

        The returned list is sorted in an order that ensures all dependencies
        for a module are built before that module is reached.

        The algorithm iterates through the main module list, and then
        recursively descends through the dependency tree for each module. If a
        module is already in the sorted list, it is not added again. Modules are
        added to the sorted list after the recursive call returns so that the
        module with the dependency is added to the list after the dependency.

        Refer to https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
        for more details about the algorithm.

        Raises:
            DependencyResolutionError: On circular dependencies, or on
                missing dependencies unless ignore_dependencies is set.
        """
        sorted_modules: List[str] = []
        # Mirror of sorted_modules for O(1) membership tests. The previous
        # implementation built `not_present.union(sorted_modules)` — a fresh
        # set — on every visit, making resolution accidentally quadratic.
        sorted_set: Set[str] = set()
        modules_with_missing_deps: Dict[str, str] = {}
        not_present: Set[str] = set()
        tmp_visited: Set[str] = set()

        def visit(module: str) -> None:
            if module in not_present or module in sorted_set:
                return
            # tmp_visited holds the current recursion path so that circular
            # dependencies are detected.
            if module in tmp_visited:
                raise DependencyResolutionError(
                    "Resolution failed due to circular dependencies"
                )
            try:
                deps = environment.modules[module].dependencies
            except KeyError:
                logger.debug(f"Dependent module '{module}' is not present")
                not_present.add(module)
                return
            tmp_visited.add(module)
            for dep in deps:
                visit(dep)
            tmp_visited.remove(module)
            # Dependencies that could not be sorted (i.e. are not present in
            # the environment) are recorded for reporting below.
            missing_deps = set(deps).difference(sorted_set)
            if missing_deps:
                modules_with_missing_deps[module] = ",".join(missing_deps)
            sorted_modules.append(module)
            sorted_set.add(module)

        for module in environment.modules:
            visit(module)
        if modules_with_missing_deps:
            if not self.ignore_dependencies:
                raise DependencyResolutionError(
                    f"Dependency resolution failure due to missing dependencies for: {pprint.pformat(modules_with_missing_deps)}"
                )
            logger.debug(
                f"Missing dependencies: {pprint.pformat(modules_with_missing_deps)}"
            )
        logger.debug(f"Sorted modules by dependency: {sorted_modules}")
        return sorted_modules