Skip to content

Commit 86ed0d7

Browse files
ManonMarchandbsipocz
authored andcommitted
refactor: remove Column and Join from SimbadClass
1 parent 7c51ec9 commit 86ed0d7

File tree

3 files changed

+92
-93
lines changed

3 files changed

+92
-93
lines changed

astroquery/simbad/core.py

+59-60
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,23 @@ def _cached_query_tap(tap, query: str, *, maxrec=10000):
7575
return tap.search(query, maxrec=maxrec).to_table()
7676

7777

78+
@dataclass(frozen=True)
79+
class _Column:
80+
"""A class to define a column in a SIMBAD query."""
81+
table: str
82+
name: str
83+
alias: str = field(default=None)
84+
85+
86+
@dataclass(frozen=True)
87+
class _Join:
88+
"""A class to define a join between two tables."""
89+
table: str
90+
column_left: Any
91+
column_right: Any
92+
join_type: str = field(default="JOIN")
93+
94+
7895
class SimbadClass(BaseVOQuery):
7996
"""The class for querying the SIMBAD web service.
8097
@@ -84,30 +101,15 @@ class SimbadClass(BaseVOQuery):
84101
"""
85102
SIMBAD_URL = 'https://' + conf.server + '/simbad/sim-script'
86103

87-
@dataclass(frozen=True)
88-
class Column:
89-
"""A class to define a column in a SIMBAD query."""
90-
table: str
91-
name: str
92-
alias: str = field(default=None)
93-
94-
@dataclass(frozen=True)
95-
class Join:
96-
"""A class to define a join between two tables."""
97-
table: str
98-
column_left: Any
99-
column_right: Any
100-
join_type: str = field(default="JOIN")
101-
102104
def __init__(self, ROW_LIMIT=None):
103105
super().__init__()
104106
# to create the TAPService
105107
self._server = conf.server
106108
self._tap = None
107109
self._hardlimit = None
108110
# attributes to construct ADQL queries
109-
self._columns_in_output = None # a list of Simbad.Column
110-
self.joins = [] # a list of Simbad.Join
111+
self._columns_in_output = None # a list of _Column
112+
self.joins = [] # a list of _Join
111113
self.criteria = [] # a list of strings
112114
self.ROW_LIMIT = ROW_LIMIT
113115

@@ -165,7 +167,7 @@ def hardlimit(self):
165167

166168
@property
167169
def columns_in_output(self):
168-
"""A list of Simbad.Column.
170+
"""A list of _Column.
169171
170172
They will be included in the output of the following methods:
171173
@@ -178,7 +180,7 @@ def columns_in_output(self):
178180
179181
"""
180182
if self._columns_in_output is None:
181-
self._columns_in_output = [Simbad.Column("basic", item)
183+
self._columns_in_output = [_Column("basic", item)
182184
for item in conf.default_columns]
183185
return self._columns_in_output
184186

@@ -277,7 +279,7 @@ def _get_bundle_columns(self, bundle_name):
277279
278280
Returns
279281
-------
280-
list[Simbad.Column]
282+
list[simbad._Column]
281283
The list of columns corresponding to the selected bundle.
282284
"""
283285
basic_columns = set(map(str.casefold, set(self.list_columns("basic")["column_name"])))
@@ -287,10 +289,10 @@ def _get_bundle_columns(self, bundle_name):
287289

288290
if bundle_name in bundle_entries:
289291
bundle = bundle_entries[bundle_name]
290-
columns = [Simbad.Column("basic", column) for column in basic_columns
292+
columns = [_Column("basic", column) for column in basic_columns
291293
if column.startswith(bundle["tap_startswith"])]
292294
if "tap_column" in bundle:
293-
columns = [Simbad.Column("basic", column) for column in bundle["tap_column"]] + columns
295+
columns = [_Column("basic", column) for column in bundle["tap_column"]] + columns
294296
return columns
295297

296298
def _add_table_to_output(self, table):
@@ -308,7 +310,7 @@ def _add_table_to_output(self, table):
308310
table = table.casefold()
309311

310312
if table == "basic":
311-
self.columns_in_output.append(Simbad.Column(table, "*"))
313+
self.columns_in_output.append(_Column(table, "*"))
312314
return
313315

314316
linked_to_basic = self.list_linked_tables("basic")
@@ -329,10 +331,10 @@ def _add_table_to_output(self, table):
329331
alias = [f'"{table}.{column}"' if not column.startswith(table) else None for column in columns]
330332

331333
# modify the attributes here
332-
self.columns_in_output += [Simbad.Column(table, column, alias)
334+
self.columns_in_output += [_Column(table, column, alias)
333335
for column, alias in zip(columns, alias)]
334-
self.joins += [Simbad.Join(table, Simbad.Column("basic", link["target_column"]),
335-
Simbad.Column(table, link["from_column"]))]
336+
self.joins += [_Join(table, _Column("basic", link["target_column"]),
337+
_Column(table, link["from_column"]))]
336338

337339
def add_votable_fields(self, *args):
338340
"""Add columns to the output of a SIMBAD query.
@@ -360,8 +362,8 @@ def add_votable_fields(self, *args):
360362
>>> from astroquery.simbad import Simbad
361363
>>> simbad = Simbad()
362364
>>> simbad.add_votable_fields('sp_type', 'sp_qual', 'sp_bibcode') # doctest: +REMOTE_DATA
363-
>>> simbad.columns_in_output[0] # doctest: +REMOTE_DATA
364-
SimbadClass.Column(table='basic', name='main_id', alias=None)
365+
>>> simbad.get_votable_fields() # doctest: +REMOTE_DATA
366+
['basic.main_id', 'basic.ra', 'basic.dec', 'basic.coo_err_maj', 'basic.coo_err_min', ...
365367
"""
366368

367369
# the legacy way of adding fluxes is the only case-dependant option
@@ -375,9 +377,9 @@ def add_votable_fields(self, *args):
375377
flux_filter = re.findall(r"\((\w+)\)", arg)[0]
376378
if len(flux_filter) == 1 and flux_filter.islower():
377379
flux_filter = flux_filter + "_"
378-
self.joins.append(self.Join("allfluxes", self.Column("basic", "oid"),
379-
self.Column("allfluxes", "oidref")))
380-
self.columns_in_output.append(self.Column("allfluxes", flux_filter))
380+
self.joins.append(_Join("allfluxes", _Column("basic", "oid"),
381+
_Column("allfluxes", "oidref")))
382+
self.columns_in_output.append(_Column("allfluxes", flux_filter))
381383
args.remove(arg)
382384

383385
# casefold args
@@ -391,7 +393,7 @@ def add_votable_fields(self, *args):
391393
bundles = output_options[output_options["type"] == "bundle of basic columns"]["name"]
392394

393395
# Add columns from basic
394-
self.columns_in_output += [Simbad.Column("basic", column) for column in args if column in basic_columns]
396+
self.columns_in_output += [_Column("basic", column) for column in args if column in basic_columns]
395397

396398
# Add tables
397399
tables_to_add = [table for table in args if table in all_tables]
@@ -415,7 +417,7 @@ def add_votable_fields(self, *args):
415417
# some columns are still there but under a new name
416418
if field_type == "alias":
417419
tap_column = field_data["tap_column"]
418-
self.columns_in_output.append(Simbad.Column("basic", tap_column))
420+
self.columns_in_output.append(_Column("basic", tap_column))
419421
warning_message = (f"'{votable_field}' has been renamed '{tap_column}'. You'll see it "
420422
"appearing with its new name in the output table")
421423
warnings.warn(warning_message, DeprecationWarning, stacklevel=2)
@@ -462,7 +464,7 @@ def reset_votable_fields(self):
462464
- `query_criteria`.
463465
464466
"""
465-
self.columns_in_output = [Simbad.Column("basic", item)
467+
self.columns_in_output = [_Column("basic", item)
466468
for item in conf.default_columns]
467469
self.joins = []
468470
self.criteria = []
@@ -555,9 +557,9 @@ def query_object(self, object_name, *, wildcard=False,
555557
"""
556558
top, columns, joins, instance_criteria = self._get_query_parameters()
557559

558-
columns.append(Simbad.Column("ident", "id", "matched_id"))
560+
columns.append(_Column("ident", "id", "matched_id"))
559561

560-
joins.append(Simbad.Join("ident", Simbad.Column("basic", "oid"), Simbad.Column("ident", "oidref")))
562+
joins.append(_Join("ident", _Column("basic", "oid"), _Column("ident", "oidref")))
561563

562564
if wildcard:
563565
instance_criteria.append(rf" regexp(id, '{_wildcard_to_regexp(object_name)}') = 1")
@@ -626,10 +628,10 @@ def query_objects(self, object_names, *, wildcard=False, criteria=None,
626628
instance_criteria.append(f"({criteria})")
627629

628630
if wildcard:
629-
columns.append(Simbad.Column("ident", "id", "matched_id"))
630-
joins += [Simbad.Join("ident", Simbad.Column("basic", "oid"),
631-
Simbad.Column("ident", "oidref"))]
632-
list_criteria = [f"regexp(id, '{_wildcard_to_regexp(object_name)}') = 1" for object_name in object_names]
631+
columns.append(_Column("ident", "id", "matched_id"))
632+
joins += [_Join("ident", _Column("basic", "oid"), _Column("ident", "oidref"))]
633+
list_criteria = [f"regexp(id, '{_wildcard_to_regexp(object_name)}') = 1"
634+
for object_name in object_names]
633635
instance_criteria += [f'({" OR ".join(list_criteria)})']
634636

635637
return self._query(top, columns, joins, instance_criteria,
@@ -640,16 +642,15 @@ def query_objects(self, object_names, *, wildcard=False, criteria=None,
640642
upload = Table({"user_specified_id": object_names,
641643
"object_number_id": list(range(1, len(object_names) + 1))})
642644
upload_name = "TAP_UPLOAD.script_infos"
643-
columns.append(Simbad.Column(upload_name, "*"))
645+
columns.append(_Column(upload_name, "*"))
644646

645-
left_joins = [Simbad.Join("ident", Simbad.Column(upload_name, "user_specified_id"),
646-
Simbad.Column("ident", "id"), "LEFT JOIN"),
647-
Simbad.Join("basic", Simbad.Column("basic", "oid"),
648-
Simbad.Column("ident", "oidref"), "LEFT JOIN")]
647+
left_joins = [_Join("ident", _Column(upload_name, "user_specified_id"),
648+
_Column("ident", "id"), "LEFT JOIN"),
649+
_Join("basic", _Column("basic", "oid"),
650+
_Column("ident", "oidref"), "LEFT JOIN")]
649651
for join in joins:
650-
left_joins.append(Simbad.Join(join.table,
651-
join.column_left,
652-
join.column_right, "LEFT JOIN"))
652+
left_joins.append(_Join(join.table, join.column_left,
653+
join.column_right, "LEFT JOIN"))
653654
return self._query(top, columns, left_joins, instance_criteria,
654655
from_table=upload_name,
655656
get_query_payload=get_query_payload,
@@ -814,9 +815,9 @@ def query_catalog(self, catalog, *, criteria=None, get_query_payload=False,
814815
"""
815816
top, columns, joins, instance_criteria = self._get_query_parameters()
816817

817-
columns.append(Simbad.Column("ident", "id", "catalog_id"))
818+
columns.append(_Column("ident", "id", "catalog_id"))
818819

819-
joins += [Simbad.Join("ident", Simbad.Column("basic", "oid"), Simbad.Column("ident", "oidref"))]
820+
joins += [_Join("ident", _Column("basic", "oid"), _Column("ident", "oidref"))]
820821

821822
instance_criteria.append(fr"id LIKE '{catalog} %'")
822823
if criteria:
@@ -848,13 +849,11 @@ def query_bibobj(self, bibcode, *, criteria=None,
848849
"""
849850
top, columns, joins, instance_criteria = self._get_query_parameters()
850851

851-
joins += [Simbad.Join("has_ref", Simbad.Column("basic", "oid"),
852-
Simbad.Column("has_ref", "oidref")),
853-
Simbad.Join("ref", Simbad.Column("has_ref", "oidbibref"),
854-
Simbad.Column("ref", "oidbib"))]
852+
joins += [_Join("has_ref", _Column("basic", "oid"), _Column("has_ref", "oidref")),
853+
_Join("ref", _Column("has_ref", "oidbibref"), _Column("ref", "oidbib"))]
855854

856-
columns += [Simbad.Column("ref", "bibcode"),
857-
Simbad.Column("has_ref", "obj_freq")]
855+
columns += [_Column("ref", "bibcode"),
856+
_Column("has_ref", "obj_freq")]
858857

859858
instance_criteria.append(f"bibcode = '{_adql_parameter(bibcode)}'")
860859
if criteria:
@@ -1071,11 +1070,11 @@ def query_criteria(self, *args, get_query_payload=False, **kwargs):
10711070
added_criteria = f"({CriteriaTranslator.parse(' & '.join(list(list(args) + list_kwargs)))})"
10721071
instance_criteria.append(added_criteria)
10731072
if "otypes." in added_criteria:
1074-
joins.append(self.Join("otypes", self.Column("basic", "oid"),
1075-
self.Column("otypes", "oidref")))
1073+
joins.append(_Join("otypes", _Column("basic", "oid"),
1074+
_Column("otypes", "oidref")))
10761075
if "allfluxes." in added_criteria:
1077-
joins.append(self.Join("allfluxes", self.Column("basic", "oid"),
1078-
self.Column("allfluxes", "oidref")))
1076+
joins.append(_Join("allfluxes", _Column("basic", "oid"),
1077+
_Column("allfluxes", "oidref")))
10791078
return self._query(top, columns, joins, instance_criteria,
10801079
get_query_payload=get_query_payload)
10811080

0 commit comments

Comments
 (0)