Skip to content

Commit

Permalink
Add cli_exit_on_error config option (#340)
Browse files Browse the repository at this point in the history
Co-authored-by: Hasan Ramezani <[email protected]>
  • Loading branch information
kschwab and hramezani committed Jul 19, 2024
1 parent bcbdd2a commit 4840d69
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 65 deletions.
39 changes: 31 additions & 8 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,28 @@ options:
"""
```

#### Change Whether CLI Should Exit on Error

Change whether the CLI internal parser will exit on error or raise a `SettingsError` exception by using
`cli_exit_on_error`. By default, the CLI internal parser will exit on error.

```py
import sys

from pydantic_settings import BaseSettings, SettingsError


class Settings(BaseSettings, cli_parse_args=True, cli_exit_on_error=False): ...


try:
sys.argv = ['example.py', '--bad-arg']
Settings()
except SettingsError as e:
print(e)
#> error parsing CLI: unrecognized arguments: --bad-arg
```

#### Enforce Required Arguments at CLI

Pydantic settings is designed to pull values in from various sources when instantating a model. This means a field that
Expand All @@ -884,10 +906,15 @@ import sys

from pydantic import Field

from pydantic_settings import BaseSettings
from pydantic_settings import BaseSettings, SettingsError


class Settings(BaseSettings, cli_parse_args=True, cli_enforce_required=True):
class Settings(
BaseSettings,
cli_parse_args=True,
cli_enforce_required=True,
cli_exit_on_error=False,
):
my_required_field: str = Field(description='a top level required field')


Expand All @@ -896,13 +923,9 @@ os.environ['MY_REQUIRED_FIELD'] = 'hello from environment'
try:
sys.argv = ['example.py']
Settings()
except SystemExit as e:
except SettingsError as e:
print(e)
#> 2
"""
usage: example.py [-h] --my_required_field str
example.py: error: the following arguments are required: --my_required_field
"""
#> error parsing CLI: the following arguments are required: --my_required_field
```

#### Change the None Type Parse String
Expand Down
2 changes: 2 additions & 0 deletions pydantic_settings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
PydanticBaseSettingsSource,
PyprojectTomlConfigSettingsSource,
SecretsSettingsSource,
SettingsError,
TomlConfigSettingsSource,
YamlConfigSettingsSource,
)
Expand All @@ -29,6 +30,7 @@
'PydanticBaseSettingsSource',
'SecretsSettingsSource',
'SettingsConfigDict',
'SettingsError',
'TomlConfigSettingsSource',
'YamlConfigSettingsSource',
'AzureKeyVaultSettingsSource',
Expand Down
11 changes: 11 additions & 0 deletions pydantic_settings/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class SettingsConfigDict(ConfigDict, total=False):
cli_avoid_json: bool
cli_enforce_required: bool
cli_use_class_docs_for_groups: bool
cli_exit_on_error: bool
cli_prefix: str
secrets_dir: str | Path | None
json_file: PathType | None
Expand Down Expand Up @@ -110,6 +111,8 @@ class BaseSettings(BaseModel):
_cli_enforce_required: Enforce required fields at the CLI. Defaults to `False`.
_cli_use_class_docs_for_groups: Use class docstrings in CLI group help text instead of field descriptions.
Defaults to `False`.
_cli_exit_on_error: Determines whether or not the internal parser exits with error info when an error occurs.
Defaults to `True`.
_cli_prefix: The root parser command line arguments prefix. Defaults to "".
_secrets_dir: The secret files directory. Defaults to `None`.
"""
Expand All @@ -132,6 +135,7 @@ def __init__(
_cli_avoid_json: bool | None = None,
_cli_enforce_required: bool | None = None,
_cli_use_class_docs_for_groups: bool | None = None,
_cli_exit_on_error: bool | None = None,
_cli_prefix: str | None = None,
_secrets_dir: str | Path | None = None,
**values: Any,
Expand All @@ -156,6 +160,7 @@ def __init__(
_cli_avoid_json=_cli_avoid_json,
_cli_enforce_required=_cli_enforce_required,
_cli_use_class_docs_for_groups=_cli_use_class_docs_for_groups,
_cli_exit_on_error=_cli_exit_on_error,
_cli_prefix=_cli_prefix,
_secrets_dir=_secrets_dir,
)
Expand Down Expand Up @@ -204,6 +209,7 @@ def _settings_build_values(
_cli_avoid_json: bool | None = None,
_cli_enforce_required: bool | None = None,
_cli_use_class_docs_for_groups: bool | None = None,
_cli_exit_on_error: bool | None = None,
_cli_prefix: str | None = None,
_secrets_dir: str | Path | None = None,
) -> dict[str, Any]:
Expand Down Expand Up @@ -250,6 +256,9 @@ def _settings_build_values(
if _cli_use_class_docs_for_groups is not None
else self.model_config.get('cli_use_class_docs_for_groups')
)
cli_exit_on_error = (
_cli_exit_on_error if _cli_exit_on_error is not None else self.model_config.get('cli_exit_on_error')
)
cli_prefix = _cli_prefix if _cli_prefix is not None else self.model_config.get('cli_prefix')

secrets_dir = _secrets_dir if _secrets_dir is not None else self.model_config.get('secrets_dir')
Expand Down Expand Up @@ -300,6 +309,7 @@ def _settings_build_values(
cli_avoid_json=cli_avoid_json,
cli_enforce_required=cli_enforce_required,
cli_use_class_docs_for_groups=cli_use_class_docs_for_groups,
cli_exit_on_error=cli_exit_on_error,
cli_prefix=cli_prefix,
case_sensitive=case_sensitive,
)
Expand Down Expand Up @@ -346,6 +356,7 @@ def _settings_build_values(
cli_avoid_json=False,
cli_enforce_required=False,
cli_use_class_docs_for_groups=False,
cli_exit_on_error=True,
cli_prefix='',
json_file=None,
json_file_encoding=None,
Expand Down
30 changes: 24 additions & 6 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Iterator,
List,
Mapping,
NoReturn,
Optional,
Sequence,
Tuple,
Expand Down Expand Up @@ -110,6 +111,10 @@ def import_azure_key_vault() -> None:
ENV_FILE_SENTINEL: DotenvType = Path('')


class SettingsError(ValueError):
pass


class _CliSubCommand:
pass

Expand All @@ -119,7 +124,14 @@ class _CliPositionalArg:


class _CliInternalArgParser(ArgumentParser):
pass
def __init__(self, cli_exit_on_error: bool = True, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._cli_exit_on_error = cli_exit_on_error

def error(self, message: str) -> NoReturn:
if not self._cli_exit_on_error:
raise SettingsError(f'error parsing CLI: {message}')
super().error(message)


T = TypeVar('T')
Expand All @@ -131,10 +143,6 @@ class EnvNoneType(str):
pass


class SettingsError(ValueError):
pass


class PydanticBaseSettingsSource(ABC):
"""
Abstract base class for settings sources, every settings source classes should inherit from it.
Expand Down Expand Up @@ -893,6 +901,8 @@ class CliSettingsSource(EnvSettingsSource, Generic[T]):
cli_enforce_required: Enforce required fields at the CLI. Defaults to `False`.
cli_use_class_docs_for_groups: Use class docstrings in CLI group help text instead of field descriptions.
Defaults to `False`.
cli_exit_on_error: Determines whether or not the internal parser exits with error info when an error occurs.
Defaults to `True`.
cli_prefix: Prefix for command line arguments added under the root parser. Defaults to "".
case_sensitive: Whether CLI "--arg" names should be read with case-sensitivity. Defaults to `True`.
Note: Case-insensitive matching is only supported on the internal root parser and does not apply to CLI
Expand All @@ -919,6 +929,7 @@ def __init__(
cli_avoid_json: bool | None = None,
cli_enforce_required: bool | None = None,
cli_use_class_docs_for_groups: bool | None = None,
cli_exit_on_error: bool | None = None,
cli_prefix: str | None = None,
case_sensitive: bool | None = True,
root_parser: Any = None,
Expand Down Expand Up @@ -953,6 +964,11 @@ def __init__(
if cli_use_class_docs_for_groups is not None
else settings_cls.model_config.get('cli_use_class_docs_for_groups', False)
)
self.cli_exit_on_error = (
cli_exit_on_error
if cli_exit_on_error is not None
else settings_cls.model_config.get('cli_exit_on_error', True)
)
self.cli_prefix = cli_prefix if cli_prefix is not None else settings_cls.model_config.get('cli_prefix', '')
if self.cli_prefix:
if cli_prefix.startswith('.') or cli_prefix.endswith('.') or not cli_prefix.replace('.', '').isidentifier(): # type: ignore
Expand All @@ -973,7 +989,9 @@ def __init__(
)

root_parser = (
_CliInternalArgParser(prog=self.cli_prog_name, description=settings_cls.__doc__)
_CliInternalArgParser(
cli_exit_on_error=self.cli_exit_on_error, prog=self.cli_prog_name, description=settings_cls.__doc__
)
if root_parser is None
else root_parser
)
Expand Down
Loading

0 comments on commit 4840d69

Please sign in to comment.