Skip to content

Commit

Permalink
Code refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
juntyr committed Aug 5, 2024
1 parent 9fac7af commit 8356a78
Show file tree
Hide file tree
Showing 6 changed files with 483 additions and 463 deletions.
15 changes: 7 additions & 8 deletions crates/numcodecs-python/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,20 @@ trait AnyCodec {

impl<T: Codec> AnyCodec for T {
fn encode(&self, data: AnyCowArray) -> Result<AnyArray, PyErr> {
<T as Codec>::encode(&self, data).map_err(|err| PyRuntimeError::new_err(format!("{err}")))
<T as Codec>::encode(self, data).map_err(|err| PyRuntimeError::new_err(format!("{err}")))
}

fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, PyErr> {
<T as Codec>::decode(&self, encoded)
.map_err(|err| PyRuntimeError::new_err(format!("{err}")))
<T as Codec>::decode(self, encoded).map_err(|err| PyRuntimeError::new_err(format!("{err}")))
}

fn decode_into(&self, encoded: AnyArrayView, decoded: AnyArrayViewMut) -> Result<(), PyErr> {
<T as Codec>::decode_into(&self, encoded, decoded)
<T as Codec>::decode_into(self, encoded, decoded)
.map_err(|err| PyRuntimeError::new_err(format!("{err}")))
}

fn get_config<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyDict>, PyErr> {
<T as Codec>::get_config(&self, Pythonizer::new(py))?.extract(py)
<T as Codec>::get_config(self, Pythonizer::new(py))?.extract(py)
}
}

Expand All @@ -114,7 +113,7 @@ impl<T: DynCodecType> AnyCodecType for T {
config: Bound<'py, PyDict>,
) -> Result<Box<dyn 'static + Send + Sync + AnyCodec>, PyErr> {
match <T as DynCodecType>::codec_from_config(
&self,
self,
&mut Depythonizer::from_object_bound(config.into_any()),
) {
Ok(codec) => Ok(Box::new(codec)),
Expand Down Expand Up @@ -158,10 +157,10 @@ impl RustCodec {
.ty
.codec_from_config(kwargs.unwrap_or_else(|| PyDict::new_bound(py)))?;

Ok(RustCodec {
codec,
Ok(Self {
cls_module,
cls_name,
codec,
})
}

Expand Down
2 changes: 2 additions & 0 deletions crates/numcodecs-python/src/pycodec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ impl Codec for PyCodec {
})
}

#[allow(clippy::too_many_lines)] // FIXME
fn decode_into(
&self,
encoded: AnyArrayView,
Expand Down Expand Up @@ -296,6 +297,7 @@ impl Codec for PyCodec {
.getattr(intern!(py, "asarray"))?
.call1((decoded_out,))?
.extract()?;
#[allow(clippy::unit_arg)]
if let Ok(d) = decoded_out.downcast::<PyArrayDyn<u8>>() {
if let AnyArrayBase::U8(mut decoded) = decoded {
return Ok(decoded.assign(&d.try_readonly()?.as_array()));
Expand Down
117 changes: 3 additions & 114 deletions crates/numcodecs-python/tests/crc32.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
use numcodecs::{
AnyArray, AnyArrayBase, AnyArrayView, AnyArrayViewMut, AnyCowArray, Codec, DynCodec,
DynCodecType, StaticCodec, StaticCodecType,
};
use numcodecs_python::{export_codec_class, CodecClassMethods, CodecMethods, PyCodec, Registry};
use numcodecs::{AnyArray, AnyArrayView, AnyCowArray, Codec, DynCodec, DynCodecType};
use numcodecs_python::{CodecClassMethods, CodecMethods, PyCodec, Registry};
use numpy::ndarray::{Array1, ArrayView1};
use pyo3::{
exceptions::{PyRuntimeError, PyTypeError},
prelude::*,
types::PyDict,
};
use serde::ser::SerializeMap;
use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyDict};
use serde_json::json;
use ::{convert_case as _, pythonize as _, serde as _, serde_transcode as _};

Expand Down Expand Up @@ -124,106 +116,3 @@ fn rust_api() -> Result<(), PyErr> {

Ok(())
}

#[test]
fn export() -> Result<(), PyErr> {
#[derive(Copy, Clone)]
struct NegateCodec;

impl Codec for NegateCodec {
type Error = PyErr;

fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
match data {
AnyArrayBase::F64(a) => Ok(AnyArrayBase::F64(a.map(|x| -x))),
_ => Err(PyTypeError::new_err("negate only supports f64")),
}
}

fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
match encoded {
AnyArrayBase::F64(a) => Ok(AnyArrayBase::F64(a.map(|x| -x))),
_ => Err(PyTypeError::new_err("negate only supports f64")),
}
}

fn decode_into(
&self,
encoded: AnyArrayView,
decoded: AnyArrayViewMut,
) -> Result<(), Self::Error> {
match (encoded, decoded) {
(AnyArrayBase::F64(e), AnyArrayBase::F64(mut d)) => {
d.assign(&e);
d.map_inplace(|x| *x = -(*x));
Ok(())
}
_ => Err(PyTypeError::new_err("negate only supports f64")),
}
}

fn get_config<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let mut map = serializer.serialize_map(None)?;
map.serialize_entry("id", Self::CODEC_ID)?;
map.end()
}
}

impl StaticCodec for NegateCodec {
const CODEC_ID: &'static str = "negate";

fn from_config<'de, D: serde::Deserializer<'de>>(_config: D) -> Result<Self, D::Error> {
Ok(Self)
}
}

Python::with_gil(|py| {
let module = PyModule::new_bound(py, "codecs")?;
export_codec_class(
py,
StaticCodecType::<NegateCodec>::of(),
module.as_borrowed(),
)?;

let config = PyDict::new_bound(py);
config.set_item("id", "negate")?;

// create a codec using registry lookup
let codec = Registry::get_codec(config.as_borrowed())?;
assert_eq!(codec.class().codec_id()?, "negate");

// check the codec's config
let config = codec.get_config()?;
assert_eq!(config.len(), 1);
assert_eq!(
config
.get_item("id")?
.map(|i| i.extract::<String>())
.transpose()?
.as_deref(),
Some("negate")
);

// encode and decode data with the codec
let data = &[1.0_f64, 2.0, 3.0, 4.0];
let encoded = codec.encode(
numpy::PyArray1::from_slice_bound(py, data)
.as_any()
.as_borrowed(),
)?;
let decoded = codec.decode(encoded.as_borrowed(), None)?;
// decode into an output
let decoded_out = numpy::PyArray1::<f64>::zeros_bound(py, (4,), false);
codec.decode(encoded.as_borrowed(), Some(decoded.as_any().as_borrowed()))?;

// check the encoded and decoded data
let encoded: Vec<f64> = encoded.extract()?;
let decoded: Vec<f64> = decoded.extract()?;
let decoded_out: Vec<f64> = decoded_out.extract()?;
assert_eq!(encoded, [-1.0, -2.0, -3.0, -4.0]);
assert_eq!(decoded, data);
assert_eq!(decoded_out, data);

Ok(())
})
}
111 changes: 111 additions & 0 deletions crates/numcodecs-python/tests/export.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
use numcodecs::{
AnyArray, AnyArrayBase, AnyArrayView, AnyArrayViewMut, AnyCowArray, Codec, StaticCodec,
StaticCodecType,
};
use numcodecs_python::{export_codec_class, CodecClassMethods, CodecMethods, Registry};
use pyo3::{exceptions::PyTypeError, prelude::*, types::PyDict};
use serde::ser::SerializeMap;
use ::{convert_case as _, pythonize as _, serde as _, serde_json as _, serde_transcode as _};

#[test]
fn export() -> Result<(), PyErr> {
Python::with_gil(|py| {
let module = PyModule::new_bound(py, "codecs")?;
export_codec_class(
py,
StaticCodecType::<NegateCodec>::of(),
module.as_borrowed(),
)?;

let config = PyDict::new_bound(py);
config.set_item("id", "negate")?;

// create a codec using registry lookup
let codec = Registry::get_codec(config.as_borrowed())?;
assert_eq!(codec.class().codec_id()?, "negate");

// check the codec's config
let config = codec.get_config()?;
assert_eq!(config.len(), 1);
assert_eq!(
config
.get_item("id")?
.map(|i| i.extract::<String>())
.transpose()?
.as_deref(),
Some("negate")
);

// encode and decode data with the codec
let data = &[1.0_f64, 2.0, 3.0, 4.0];
let encoded = codec.encode(
numpy::PyArray1::from_slice_bound(py, data)
.as_any()
.as_borrowed(),
)?;
let decoded = codec.decode(encoded.as_borrowed(), None)?;
// decode into an output
let decoded_out = numpy::PyArray1::<f64>::zeros_bound(py, (4,), false);
codec.decode(encoded.as_borrowed(), Some(decoded.as_any().as_borrowed()))?;

// check the encoded and decoded data
let encoded: Vec<f64> = encoded.extract()?;
let decoded: Vec<f64> = decoded.extract()?;
let decoded_out: Vec<f64> = decoded_out.extract()?;
assert_eq!(encoded, [-1.0, -2.0, -3.0, -4.0]);
assert_eq!(decoded, data);
assert_eq!(decoded_out, data);

Ok(())
})
}

#[derive(Copy, Clone)]
struct NegateCodec;

impl Codec for NegateCodec {
type Error = PyErr;

fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
match data {
AnyArrayBase::F64(a) => Ok(AnyArrayBase::F64(a.map(|x| -x))),
_ => Err(PyTypeError::new_err("negate only supports f64")),
}
}

fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
match encoded {
AnyArrayBase::F64(a) => Ok(AnyArrayBase::F64(a.map(|x| -x))),
_ => Err(PyTypeError::new_err("negate only supports f64")),
}
}

fn decode_into(
&self,
encoded: AnyArrayView,
decoded: AnyArrayViewMut,
) -> Result<(), Self::Error> {
match (encoded, decoded) {
(AnyArrayBase::F64(e), AnyArrayBase::F64(mut d)) => {
d.assign(&e);
d.map_inplace(|x| *x = -(*x));
Ok(())
}
_ => Err(PyTypeError::new_err("negate only supports f64")),
}
}

fn get_config<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let mut map = serializer.serialize_map(None)?;
map.serialize_entry("id", Self::CODEC_ID)?;
map.end()
}
}

impl StaticCodec for NegateCodec {
const CODEC_ID: &'static str = "negate";

fn from_config<'de, D: serde::Deserializer<'de>>(_config: D) -> Result<Self, D::Error> {
Ok(Self)
}
}
Loading

0 comments on commit 8356a78

Please sign in to comment.