Skip to content

Commit

Permalink
Merge pull request #45 from hydroshare/44-upgrade-to-pydantic-v2
Browse files (browse the repository at this point in the history)
[#44] upgrade to pydantic v2
  • Loading branch information
sblack-usu authored Feb 15, 2024
2 parents: e041853 + baad19a; commit 3e2295d
Show file tree
Hide file tree
Showing 29 changed files with 957 additions and 831 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7]
python-version: [3.9]

steps:
- uses: actions/checkout@v2
Expand Down
96 changes: 55 additions & 41 deletions hsmodels/schemas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import inspect
from enum import Enum

from pydantic import AnyUrl, BaseModel
from pydantic import BaseModel
from pydantic_core import Url
from rdflib import Graph, Literal, URIRef

from hsmodels.namespaces import DC, HSTERMS, ORE, RDF, RDFS1
Expand Down Expand Up @@ -71,7 +72,7 @@ def load_rdf(rdf_str, file_format='xml'):
else:
rdf_metadata = _parse(schema, g)
if schema in user_schemas.keys():
return user_schemas[schema](**rdf_metadata.dict())
return user_schemas[schema](**rdf_metadata.model_dump(exclude_none=True))
return rdf_metadata
raise Exception("Could not find schema for \n{}".format(rdf_str))

Expand All @@ -84,7 +85,7 @@ def parse_file(schema, file, file_format='xml', subject=None):
def rdf_graph(schema):
for rdf_schema, user_schema in user_schemas.items():
if isinstance(schema, user_schema):
return _rdf_graph(rdf_schema(**schema.dict(to_rdf=True)), Graph())
return _rdf_graph(rdf_schema(**schema.model_dump(to_rdf=True)), Graph())
return _rdf_graph(schema, Graph())


Expand All @@ -93,23 +94,21 @@ def rdf_string(schema, rdf_format='pretty-xml'):


def _rdf_fields(schema):
for f in schema.__fields__.values():
if f.alias not in ['rdf_subject', 'rdf_type', 'label', 'dc_type']:
predicate = f.field_info.extra.get('rdf_predicate', None)
if not predicate:
config_field_info = schema.Config.fields.get(f.name, None)
if isinstance(config_field_info, dict):
predicate = config_field_info.get('rdf_predicate', None)
for fname, finfo in schema.model_fields.items():
if fname not in ['rdf_subject', 'rdf_type', 'label', 'dc_type']:
predicate = None
if finfo.json_schema_extra:
predicate = finfo.json_schema_extra.get('rdf_predicate', None)
if not predicate:
raise Exception(
"Schema configuration error for {}, all fields must specify a rdf_predicate".format(schema)
)
yield f, predicate
yield finfo, fname, predicate


def _rdf_graph(schema, graph=None):
for f, predicate in _rdf_fields(schema):
values = getattr(schema, f.name, None)
for f, fname, predicate in _rdf_fields(schema):
values = getattr(schema, fname, None)
if values is not None:
if not isinstance(values, list):
# handle single values as a list to simplify
Expand All @@ -122,8 +121,8 @@ def _rdf_graph(schema, graph=None):
graph = _rdf_graph(value, graph)
else:
# primitive value
if isinstance(value, AnyUrl):
value = URIRef(value)
if isinstance(value, Url):
value = URIRef(str(value))
elif isinstance(value, TermEnum):
value = URIRef(value.value)
elif isinstance(value, Enum):
Expand All @@ -146,18 +145,20 @@ def get_args(t):


def _parse(schema, metadata_graph, subject=None):
def nested_class(field):
if field.sub_fields:
clazz = get_args(field.outer_type_)[0]
else:
clazz = field.outer_type_
if inspect.isclass(clazz):
return issubclass(clazz, BaseModel)
return False
def get_nested_class(field):
origin = field.annotation
if origin:
if inspect.isclass(origin) and issubclass(origin, BaseModel):
return origin
if get_args(origin):
clazz = get_args(origin)[0]
if inspect.isclass(clazz) and issubclass(clazz, BaseModel):
return clazz
return None

def class_rdf_type(schema):
if schema.__fields__['rdf_type']:
return schema.__fields__['rdf_type'].default
if schema.model_fields['rdf_type']:
return schema.model_fields['rdf_type'].default
return None

if not subject:
Expand All @@ -168,34 +169,47 @@ def class_rdf_type(schema):
subject = metadata_graph.value(predicate=RDF.type, object=target_class)
if not subject:
raise Exception("Could not find subject for predicate=RDF.type, object={}".format(target_class))

kwargs = {}
for f, predicate in _rdf_fields(schema):
for f, name, predicate in _rdf_fields(schema):
parsed = []
for value in metadata_graph.objects(subject=subject, predicate=predicate):
if nested_class(f):
if f.sub_fields:
# list
clazz = f.sub_fields[0].outer_type_
else:
# single
clazz = f.outer_type_
parsed_class = _parse(clazz, metadata_graph, value)
nested_clazz = get_nested_class(f)
if nested_clazz:
parsed_class = _parse(nested_clazz, metadata_graph, value)
if parsed_class:
parsed.append(parsed_class)
elif f.sub_fields:
parsed.append([])
else:
parsed_value = str(value.toPython())
# primitive value
# not a nested class (primitive class and not a subclass of BaseModel)

origin = f.annotation
origin_clazz = getattr(origin, '__origin__', None)
parsed_value = None
if origin_clazz is list:
clazz = origin.__args__[0]
if issubclass(clazz, BaseModel):
parsed_class = _parse(clazz, metadata_graph, value)
if parsed_class:
parsed.append(parsed_class)
else:
# primitive value
parsed_value = str(value.toPython())
else:
# primitive value
parsed_value = str(value.toPython())

if parsed_value:
parsed.append(parsed_value)

if len(parsed) > 0:
if f.sub_fields:
origin = f.annotation
origin_clazz = getattr(origin, '__origin__', None)
if origin_clazz is list:
# list
kwargs[f.name] = parsed
kwargs[name] = parsed
else:
# single
kwargs[f.name] = parsed[0]
kwargs[name] = parsed[0]
if kwargs:
instance = schema(**kwargs, rdf_subject=subject)
return instance
Expand Down
Loading

0 comments on commit 3e2295d

Please sign in to comment.