Skip to content

Commit f4759a3

Browse files
authored
validating default value (#479)
* validating default value * internal ValueError on invalid default value * raise ValidationError on invalid default
1 parent cdeba70 commit f4759a3

File tree

12 files changed

+157
-46
lines changed

12 files changed

+157
-46
lines changed

.rustfmt.toml

-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
11
max_width = 120
2-
imports_granularity = "Module"

pydantic_core/core_schema.py

+6
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ class CoreConfig(TypedDict, total=False):
3232
# used on typed-dicts and tagged union keys
3333
from_attributes: bool
3434
revalidate_models: bool
35+
# whether to validate default values during validation, default False
36+
validate_default: bool
3537
# used on typed-dicts and arguments
3638
populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1
3739
# fields related to string fields only
@@ -2068,6 +2070,7 @@ class WithDefaultSchema(TypedDict, total=False):
20682070
default: Any
20692071
default_factory: Callable[[], Any]
20702072
on_error: Literal['raise', 'omit', 'default'] # default: 'raise'
2073+
validate_default: bool # default: False
20712074
strict: bool
20722075
ref: str
20732076
metadata: Any
@@ -2083,6 +2086,7 @@ def with_default_schema(
20832086
default: Any = Omitted,
20842087
default_factory: Callable[[], Any] | None = None,
20852088
on_error: Literal['raise', 'omit', 'default'] | None = None,
2089+
validate_default: bool | None = None,
20862090
strict: bool | None = None,
20872091
ref: str | None = None,
20882092
metadata: Any = None,
@@ -2107,6 +2111,7 @@ def with_default_schema(
21072111
default: The default value to use
21082112
default_factory: A function that returns the default value to use
21092113
on_error: What to do if the schema validation fails. One of 'raise', 'omit', 'default'
2114+
validate_default: Whether the default value should be validated
21102115
strict: Whether the underlying schema should be validated with strict mode
21112116
ref: See [TODO] for details
21122117
metadata: See [TODO] for details
@@ -2117,6 +2122,7 @@ def with_default_schema(
21172122
schema=schema,
21182123
default_factory=default_factory,
21192124
on_error=on_error,
2125+
validate_default=validate_default,
21202126
strict=strict,
21212127
ref=ref,
21222128
metadata=metadata,

src/serializers/shared.rs

+4
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,10 @@ pub(crate) trait TypeSerializer: Send + Sync + Clone + Debug {
275275
fn retry_with_lax_check(&self) -> bool {
276276
false
277277
}
278+
279+
fn get_default(&self, _py: Python) -> PyResult<Option<PyObject>> {
280+
Ok(None)
281+
}
278282
}
279283

280284
pub(crate) struct PydanticSerializer<'py> {

src/serializers/type_serializers/typed_dict.rs

+2-3
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ use crate::build_context::BuildContext;
1111
use crate::build_tools::{py_error_type, schema_or_config, SchemaDict};
1212
use crate::PydanticSerializationUnexpectedValue;
1313

14-
use super::with_default::get_default;
1514
use super::{
1615
infer_json_key, infer_serialize, infer_to_python, py_err_se_err, BuildSerializer, CombinedSerializer, Extra,
1716
PydanticSerializer, SchemaFilter, SerializeInfer, TypeSerializer,
@@ -145,8 +144,8 @@ impl TypedDictSerializer {
145144

146145
fn exclude_default(&self, value: &PyAny, extra: &Extra, field: &TypedDictField) -> PyResult<bool> {
147146
if extra.exclude_defaults {
148-
if let Some(default) = get_default(value.py(), &field.serializer)? {
149-
if value.eq(default.as_ref())? {
147+
if let Some(default) = field.serializer.get_default(value.py())? {
148+
if value.eq(default)? {
150149
return Ok(true);
151150
}
152151
}

src/serializers/type_serializers/with_default.rs

+4-11
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,6 @@ impl BuildSerializer for WithDefaultSerializer {
3434
}
3535
}
3636

37-
pub(super) fn get_default<'a>(
38-
py: Python<'a>,
39-
serializer: &'a CombinedSerializer,
40-
) -> PyResult<Option<Cow<'a, PyObject>>> {
41-
if let CombinedSerializer::WithDefault(serializer) = serializer {
42-
serializer.default.default_value(py)
43-
} else {
44-
Ok(None)
45-
}
46-
}
47-
4837
impl TypeSerializer for WithDefaultSerializer {
4938
fn to_python(
5039
&self,
@@ -79,4 +68,8 @@ impl TypeSerializer for WithDefaultSerializer {
7968
fn retry_with_lax_check(&self) -> bool {
8069
self.serializer.retry_with_lax_check()
8170
}
71+
72+
fn get_default(&self, py: Python) -> PyResult<Option<PyObject>> {
73+
self.default.default_value(py)
74+
}
8275
}

src/validators/arguments.rs

+3-4
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ use crate::input::{GenericArguments, Input};
1010
use crate::lookup_key::LookupKey;
1111
use crate::recursion_guard::RecursionGuard;
1212

13-
use super::with_default::get_default;
1413
use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator};
1514

1615
#[derive(Debug, Clone)]
@@ -221,11 +220,11 @@ impl Validator for ArgumentsValidator {
221220
}
222221
}
223222
(None, None) => {
224-
if let Some(value) = get_default(py, &parameter.validator)? {
223+
if let Some(value) = parameter.validator.default_value(py, Some(parameter.name.as_str()), extra, slots, recursion_guard)? {
225224
if let Some(ref kwarg_key) = parameter.kwarg_key {
226-
output_kwargs.set_item(kwarg_key, value.as_ref())?;
225+
output_kwargs.set_item(kwarg_key, value)?;
227226
} else {
228-
output_args.push(value.as_ref().clone_ref(py));
227+
output_args.push(value);
229228
}
230229
} else if parameter.kwarg_key.is_some() {
231230
errors.push(ValLineError::new_with_loc(

src/validators/dataclass.rs

+8-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ use crate::validators::function::convert_err;
1414

1515
use super::arguments::{json_get, json_slice, py_get, py_slice};
1616
use super::model::{create_class, force_setattr};
17-
use super::with_default::get_default;
1817
use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator};
1918

2019
#[derive(Debug, Clone)]
@@ -221,8 +220,14 @@ impl Validator for DataclassArgsValidator {
221220
}
222221
// found neither, check if there is a default value, otherwise error
223222
(None, None) => {
224-
if let Some(value) = get_default(py, &field.validator)? {
225-
set_item!(field, value.as_ref().clone_ref(py));
223+
if let Some(value) = field.validator.default_value(
224+
py,
225+
Some(field.name.as_str()),
226+
&extra,
227+
slots,
228+
recursion_guard,
229+
)? {
230+
set_item!(field, value);
226231
} else {
227232
errors.push(ValLineError::new_with_loc(
228233
ErrorType::MissingKeywordArgument,

src/validators/mod.rs

+13-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use pyo3::{intern, PyTraverseError, PyVisit};
99

1010
use crate::build_context::BuildContext;
1111
use crate::build_tools::{py_err, py_error_type, SchemaDict, SchemaError};
12-
use crate::errors::{ValError, ValResult, ValidationError};
12+
use crate::errors::{LocItem, ValError, ValResult, ValidationError};
1313
use crate::input::Input;
1414
use crate::questions::{Answers, Question};
1515
use crate::recursion_guard::RecursionGuard;
@@ -602,6 +602,18 @@ pub trait Validator: Send + Sync + Clone + Debug {
602602
recursion_guard: &'s mut RecursionGuard,
603603
) -> ValResult<'data, PyObject>;
604604

605+
/// Get a default value, currently only used by `WithDefaultValidator`
606+
fn default_value<'s, 'data>(
607+
&'s self,
608+
_py: Python<'data>,
609+
_outer_loc: Option<impl Into<LocItem>>,
610+
_extra: &Extra,
611+
_slots: &'data [CombinedValidator],
612+
_recursion_guard: &'s mut RecursionGuard,
613+
) -> ValResult<'data, Option<PyObject>> {
614+
Ok(None)
615+
}
616+
605617
/// `get_name` generally returns `Self::EXPECTED_TYPE` or some other clear identifier of the validator
606618
/// this is used in the error location in unions, and in the top level message in `ValidationError`
607619
fn get_name(&self) -> &str;

src/validators/tuple.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ use crate::input::{GenericCollection, Input};
88
use crate::recursion_guard::RecursionGuard;
99

1010
use super::list::{get_items_schema, length_check};
11-
use super::with_default::get_default;
1211
use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator};
1312

1413
#[derive(Debug, Clone)]
@@ -155,8 +154,10 @@ impl Validator for TuplePositionalValidator {
155154
Err(err) => return Err(err),
156155
},
157156
None => {
158-
if let Some(value) = get_default(py, &validator)? {
159-
output.push(value.as_ref().clone_ref(py));
157+
if let Some(value) =
158+
validator.default_value(py, Some(index), extra, slots, recursion_guard)?
159+
{
160+
output.push(value);
160161
} else {
161162
errors.push(ValLineError::new_with_loc(ErrorType::Missing, input, index));
162163
}

src/validators/typed_dict.rs

+2-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ use crate::lookup_key::LookupKey;
1515
use crate::questions::Question;
1616
use crate::recursion_guard::RecursionGuard;
1717

18-
use super::with_default::get_default;
1918
use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator};
2019

2120
#[derive(Debug, Clone)]
@@ -230,8 +229,8 @@ impl Validator for TypedDictValidator {
230229
Err(err) => return Err(err),
231230
}
232231
continue;
233-
} else if let Some(value) = get_default(py, &field.validator)? {
234-
output_dict.set_item(&field.name_py, value.as_ref())?;
232+
} else if let Some(value) = field.validator.default_value(py, Some(field.name.as_str()), &extra, slots, recursion_guard)? {
233+
output_dict.set_item(&field.name_py, value)?;
235234
} else if field.required {
236235
errors.push(ValLineError::new_with_loc(
237236
ErrorType::Missing,

src/validators/with_default.rs

+39-16
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
use std::borrow::Cow;
2-
31
use pyo3::intern;
42
use pyo3::prelude::*;
53
use pyo3::types::PyDict;
64

7-
use crate::build_tools::{py_err, SchemaDict};
8-
use crate::errors::{ValError, ValResult};
5+
use crate::build_tools::{py_err, schema_or_config_same, SchemaDict};
6+
use crate::errors::{LocItem, ValError, ValResult};
97
use crate::input::Input;
108
use crate::questions::Question;
119
use crate::recursion_guard::RecursionGuard;
@@ -33,10 +31,10 @@ impl DefaultType {
3331
}
3432
}
3533

36-
pub fn default_value(&self, py: Python) -> PyResult<Option<Cow<PyObject>>> {
34+
pub fn default_value(&self, py: Python) -> PyResult<Option<PyObject>> {
3735
match self {
38-
Self::Default(ref default) => Ok(Some(Cow::Borrowed(default))),
39-
Self::DefaultFactory(ref default_factory) => Ok(Some(Cow::Owned(default_factory.call0(py)?))),
36+
Self::Default(ref default) => Ok(Some(default.clone_ref(py))),
37+
Self::DefaultFactory(ref default_factory) => Ok(Some(default_factory.call0(py)?)),
4038
Self::None => Ok(None),
4139
}
4240
}
@@ -54,6 +52,7 @@ pub struct WithDefaultValidator {
5452
default: DefaultType,
5553
on_error: OnError,
5654
validator: Box<CombinedValidator>,
55+
validate_default: bool,
5756
name: String,
5857
}
5958

@@ -89,6 +88,7 @@ impl BuildValidator for WithDefaultValidator {
8988
default,
9089
on_error,
9190
validator,
91+
validate_default: schema_or_config_same(schema, config, intern!(py, "validate_default"))?.unwrap_or(false),
9292
name,
9393
}
9494
.into())
@@ -108,12 +108,43 @@ impl Validator for WithDefaultValidator {
108108
Ok(v) => Ok(v),
109109
Err(e) => match self.on_error {
110110
OnError::Raise => Err(e),
111-
OnError::Default => Ok(self.default.default_value(py)?.unwrap().as_ref().clone()),
111+
OnError::Default => Ok(self
112+
.default_value(py, None::<usize>, extra, slots, recursion_guard)?
113+
.unwrap()),
112114
OnError::Omit => Err(ValError::Omit),
113115
},
114116
}
115117
}
116118

119+
fn default_value<'s, 'data>(
120+
&'s self,
121+
py: Python<'data>,
122+
outer_loc: Option<impl Into<LocItem>>,
123+
extra: &Extra,
124+
slots: &'data [CombinedValidator],
125+
recursion_guard: &'s mut RecursionGuard,
126+
) -> ValResult<'data, Option<PyObject>> {
127+
match self.default.default_value(py)? {
128+
Some(dft) => {
129+
if self.validate_default {
130+
match self.validate(py, dft.into_ref(py), extra, slots, recursion_guard) {
131+
Ok(v) => Ok(Some(v)),
132+
Err(e) => {
133+
if let Some(outer_loc) = outer_loc {
134+
Err(e.with_outer_location(outer_loc.into()))
135+
} else {
136+
Err(e)
137+
}
138+
}
139+
}
140+
} else {
141+
Ok(Some(dft))
142+
}
143+
}
144+
None => Ok(None),
145+
}
146+
}
147+
117148
fn get_name(&self) -> &str {
118149
&self.name
119150
}
@@ -136,11 +167,3 @@ impl WithDefaultValidator {
136167
matches!(self.on_error, OnError::Omit)
137168
}
138169
}
139-
140-
pub fn get_default<'a>(py: Python<'a>, validator: &'a CombinedValidator) -> PyResult<Option<Cow<'a, PyObject>>> {
141-
if let CombinedValidator::WithDefault(validator) = validator {
142-
validator.default.default_value(py)
143-
} else {
144-
Ok(None)
145-
}
146-
}

0 commit comments

Comments
 (0)