|
| 1 | +import os |
1 | 2 | import uuid
|
| 3 | +import random |
| 4 | +import pathlib |
| 5 | +from string import Template |
2 | 6 | from pkg_resources import resource_filename
|
3 | 7 |
|
4 | 8 | from clumper import Clumper
|
| 9 | +from IPython.core.display import HTML |
5 | 10 | from bokeh.models import ColumnDataSource
|
6 | 11 | from bokeh.plotting import figure, show
|
7 | 12 | from bokeh.models import PolyDrawTool, PolyEditTool
|
@@ -217,3 +222,58 @@ def data(self):
|
217 | 222 | for k, v in self.poly_patches.items()
|
218 | 223 | },
|
219 | 224 | }
|
| 225 | + |
| 226 | + |
| 227 | +def _random_string(): |
| 228 | + """Generates a random HTML id for d3 charts.""" |
| 229 | + return "".join([random.choice("qwertyuiopasdfghjklzxcvbnm") for _ in range(6)]) |
| 230 | + |
| 231 | + |
def _read_parcoords_asset(filename):
    """Returns the text of a file bundled in hulearn's `static/parcoords` folder."""
    # pkg_resources resource names are always "/"-separated, regardless of the
    # operating system, so the name is not built with `os.path.join`.
    asset_path = resource_filename("hulearn", f"static/parcoords/{filename}")
    # Read with an explicit encoding so the result does not depend on the
    # platform's default locale. NOTE(review): assumes the bundled assets are
    # stored as UTF-8 -- confirm.
    return pathlib.Path(asset_path).read_text(encoding="utf-8")


def parallel_coordinates(dataf, label, height=200):
    """
    Creates an interactive parallel coordinates chart to help with classification tasks.

    The chart is produced by filling in the bundled `template.html` with the
    data (as JSON records), the bundled d3/parcoords javascript and css, a
    random element id and the requested height.

    Arguments:
        dataf: the dataframe to render
        label: the column that represents the label, will be used for coloring
        height: the height of the chart, in pixels

    Returns:
        An `IPython.core.display.HTML` object wrapping the rendered template.

    Usage:

    ```python
    from hulearn.datasets import load_titanic
    from hulearn.experimental.interactive import parallel_coordinates

    df = load_titanic(as_frame=True)
    parallel_coordinates(df, label="survived", height=200)
    ```
    """
    template = Template(_read_parcoords_asset("template.html"))

    # The label column is renamed to the fixed name "label" before serialising.
    # NOTE(review): presumably the template's javascript looks up that key for
    # coloring -- confirm against template.html.
    json_data = dataf.rename(columns={label: "label"}).to_json(orient="records")

    rendered = template.substitute(
        {
            "data": json_data,
            "id": _random_string(),
            "style": _read_parcoords_asset("d3.parcoords.css"),
            "d3_blob": _read_parcoords_asset("d3.min.js"),
            "parcoords_stuff": _read_parcoords_asset("d3.parcoords.js"),
            "height": f"{height}px",
        }
    )
    return HTML(rendered)
0 commit comments