diff --git a/odo/backends/csv.py b/odo/backends/csv.py index cc1d26ec..714a0d7d 100644 --- a/odo/backends/csv.py +++ b/odo/backends/csv.py @@ -18,7 +18,7 @@ import datashape from datashape import discover, Record, Option -from datashape.predicates import isrecord +from datashape.predicates import isrecord, isdimension from datashape.dispatch import dispatch from ..compatibility import unicode, PY2 @@ -140,18 +140,26 @@ class CSV(object): If the csv file has a header or not encoding : str (default utf-8) File encoding + user_dshape: datashape or string representation + user specified datashape kwargs : other... Various choices about dialect """ canonical_extension = 'csv' def __init__(self, path, has_header=None, encoding='utf-8', - sniff_nbytes=10000, **kwargs): + sniff_nbytes=10000, user_dshape=None, **kwargs): self.path = path self._has_header = has_header self.encoding = encoding or 'utf-8' self._kwargs = kwargs self._sniff_nbytes = sniff_nbytes + if user_dshape: + if isinstance(user_dshape, (str, unicode)): + user_dshape = datashape.dshape(user_dshape) + if not isrecord(user_dshape.measure): + raise TypeError('Please provide a Record dshape for the csv') + self._dshape = user_dshape def _sniff_dialect(self, path): kwargs = self._kwargs @@ -330,6 +338,9 @@ def _(): @discover.register(CSV) def discover_csv(c, nrows=1000, **kwargs): + if c._dshape: + return c._dshape + df = csv_to_dataframe(c, nrows=nrows, **kwargs) df = coerce_datetimes(df) diff --git a/odo/backends/tests/test_csv.py b/odo/backends/tests/test_csv.py index 257d8cc2..eca49fd5 100644 --- a/odo/backends/tests/test_csv.py +++ b/odo/backends/tests/test_csv.py @@ -398,6 +398,17 @@ def test_discover_with_dotted_names(): assert dshape == datashape.dshape('var * {"a.b": int64, "c.d": int64}') assert dshape.measure.names == [u'a.b', u'c.d'] +def test_discover_csv_with_fixed_dshape(): + with filetext('name,val\nAlice,1\n,0\nBob,2') as fn: + ds = datashape.dshape('var * {name: string, val: float64}') + csv1 = CSV(fn, user_dshape=ds) + ds1 = discover(csv1) + assert ds1 == ds + csv2 = CSV(fn, has_header=True) + ds2 = discover(csv2) + assert ds1 != ds2 + + try: unichr