| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from __future__ import annotations |
| |
|
| | from typing import Any, TypeVar |
| |
|
| | import attrs |
| |
|
| | from omegaconf import DictConfig as LazyDict |
| |
|
| | from .misc import Color |
| |
|
| | T = TypeVar("T") |
| |
|
| |
|
| | def _is_attrs_instance(obj: object) -> bool: |
| | """ |
| | Helper function to check if an object is an instance of an attrs-defined class. |
| | |
| | Args: |
| | obj: The object to check. |
| | |
| | Returns: |
| | bool: True if the object is an instance of an attrs-defined class, False otherwise. |
| | """ |
| | return hasattr(obj, "__attrs_attrs__") |
| |
|
| |
|
def make_freezable(cls: T) -> T:
    """
    A decorator that adds the capability to freeze instances of an attrs-defined class.

    NOTE: This requires the wrapped attrs to be defined with attrs.define(slots=False) because we need
    to hack on a "_is_frozen" attribute.

    This decorator enhances an attrs-defined class with the ability to be "frozen" at runtime.
    Once an instance is frozen, its attributes cannot be changed. It also recursively freezes
    any attrs-defined objects that are attributes of the class.

    Usage:
        @make_freezable
        @attrs.define(slots=False)
        class MyClass:
            attribute1: int
            attribute2: str

        obj = MyClass(1, 'a')
        obj.freeze()  # Freeze the instance
        obj.attribute1 = 2  # Raises AttributeError

    Args:
        cls: The class to be decorated.

    Returns:
        The decorated class with added freezing capability. The class is modified in place
        (``__setattr__`` is replaced and a ``freeze`` method is attached) and returned.

    Raises:
        TypeError: If instances of ``cls`` have no per-instance ``__dict__`` (for example a
            slotted class), since the ``_is_frozen`` flag could not be stored on them.
    """

    # `hasattr(cls, "__dict__")` is true for every class (it finds the class namespace), so it
    # cannot detect slotted classes. Instances carry a `__dict__` only when some class in the
    # MRO defines the `__dict__` descriptor, so look for that instead.
    if not any("__dict__" in vars(base) for base in cls.__mro__):
        raise TypeError(
            "make_freezable cannot be used with classes that do not define __dict__. Make sure that the wrapped "
            "class was defined with `@attrs.define(slots=False)`"
        )

    original_setattr = cls.__setattr__

    def setattr_override(self, key: str, value: Any) -> None:
        """
        Override __setattr__ to allow modifications during initialization
        and prevent modifications once the instance is frozen.
        """
        # `_is_frozen` does not exist until `freeze()` runs, hence the default of False.
        # Writes to `_is_frozen` itself are always let through.
        if key != "_is_frozen" and getattr(self, "_is_frozen", False):
            raise AttributeError("Cannot modify frozen instance")
        original_setattr(self, key, value)

    cls.__setattr__ = setattr_override

    def freeze(self: object) -> None:
        """
        Freeze the instance and all its attrs-defined attributes.
        """
        # Only direct attribute values are inspected (recurse=False); each freezable child
        # takes care of its own attributes when its `freeze()` is called.
        for value in attrs.asdict(self, recurse=False).values():
            if _is_attrs_instance(value) and hasattr(value, "freeze"):
                value.freeze()
        self._is_frozen = True

    cls.freeze = freeze

    return cls
| |
|
| |
|
def _pretty_print_attrs_instance(obj: object, indent: int = 0, use_color: bool = False) -> str:
    """
    Render an attrs instance as a bulleted, newline-joined listing of its fields.

    Each field produces a ``* name: value`` line prefixed by ``indent`` copies of the
    indentation unit. A field whose value is itself an attrs instance produces a
    ``* name:`` header line followed by that value's own listing at ``indent + 1``.

    Args:
        obj: Instance of an attrs-defined class (checked with an assertion).
        indent: Current nesting depth.
        use_color: When True, the bullet, field name and value are passed through
            ``Color.cyan``, ``Color.green`` and ``Color.yellow`` respectively.

    Returns:
        The listing as a single string with lines joined by newlines.
    """
    assert attrs.has(obj.__class__)

    pad = " " * indent
    rendered: list[str] = []
    for field in attrs.fields(obj.__class__):
        field_value = getattr(obj, field.name)
        is_nested = attrs.has(field_value.__class__)

        # The bullet and field name are built the same way for nested and leaf fields.
        if use_color:
            label = pad + Color.cyan("* ") + Color.green(field.name)
        else:
            label = pad + "* " + field.name

        if is_nested:
            rendered.append(label + ":")
            rendered.append(_pretty_print_attrs_instance(field_value, indent + 1, use_color))
        elif use_color:
            rendered.append(label + ": " + Color.yellow(field_value))
        else:
            rendered.append(label + ": " + str(field_value))
    return "\n".join(rendered)
| |
|
| |
|
@make_freezable
@attrs.define(slots=False)
class JobConfig:
    """Freezable attrs config holding the three string components of a job's path.

    Every field defaults to the empty string.
    """

    # First component of ``path``.
    project: str = ""

    # Second component of ``path``.
    group: str = ""

    # Third component of ``path``.
    name: str = ""

    @property
    def path(self) -> str:
        """Return the fields joined as ``project/group/name``."""
        return "{}/{}/{}".format(self.project, self.group, self.name)
| |
|
| |
|
@make_freezable
@attrs.define(slots=False)
class Config:
    """Config for a job.

    See /README.md/Configuration System for more info.
    """

    # Model configuration; required (no default). ``LazyDict`` is omegaconf's ``DictConfig``.
    model: LazyDict

    # Job identification; defaults to a fresh ``JobConfig`` with empty fields.
    job: JobConfig = attrs.field(factory=JobConfig)

    def to_dict(self) -> dict[str, Any]:
        """Return the config as a dictionary produced by ``attrs.asdict``."""
        return attrs.asdict(self)

    def validate(self) -> None:
        """Validate that the config has all required fields.

        Raises:
            AssertionError: If the job's project, group or name is the empty string.
        """
        # Raise explicitly instead of using `assert` statements: those are removed when
        # Python runs with -O, which would silently turn validation off. The exception
        # type and messages are kept so existing callers behave the same.
        if self.job.project == "":
            raise AssertionError("Project name is required.")
        if self.job.group == "":
            raise AssertionError("Group name is required.")
        if self.job.name == "":
            raise AssertionError("Job name is required.")
| |
|