diff --git a/cases/cf.py b/cases/cf.py new file mode 100644 index 0000000..6bc5ac4 --- /dev/null +++ b/cases/cf.py @@ -0,0 +1,160 @@ +"""Couette flow case setup""" + +import jax.numpy as jnp +import numpy as np +from omegaconf import DictConfig + +from jax_sph.case_setup import SimulationSetup +from jax_sph.utils import Tag, pos_init_cartesian_2d, pos_init_cartesian_3d + + +class CF(SimulationSetup): + """Couette Flow. + + Setup based on "Modeling Low Reynolds Number Incompressible [...], Morris 1997, + and similar to PF case. + """ + + def __init__(self, cfg: DictConfig): + super().__init__(cfg) + + # custom variables related only to this Simulation + if self.case.dim == 2: + self.u_wall = jnp.array([self.special.u_x_wall, 0.0]) + elif self.case.dim == 3: + self.u_wall = jnp.array([self.special.u_x_wall, 0.0, 0.0]) + + # define offset vector + self.offset_vec = self._offset_vec() + + # relaxation configurations + if self.case.mode == "rlx": + self._set_default_rlx() + + if self.case.r0_type == "relaxed": + self._load_only_fluid = False + self._init_pos2D = self._get_relaxed_r0 + self._init_pos3D = self._get_relaxed_r0 + + def _box_size2D(self, n_walls): + dx2n = self.case.dx * n_walls * 2 + sp = self.special + return np.array([sp.L, sp.H + dx2n]) + + def _box_size3D(self, n_walls): + dx2n = self.case.dx * n_walls * 2 + sp = self.special + return np.array([sp.L, sp.H + dx2n, 0.4]) + + def _init_walls_2d(self, dx, n_walls): + sp = self.special + + # thickness of wall particles + dxn = dx * n_walls + + # horizontal and vertical blocks + horiz = pos_init_cartesian_2d(np.array([sp.L, dxn]), dx) + + # wall: bottom, top + wall_b = horiz.copy() + wall_t = horiz.copy() + np.array([0.0, sp.H + dxn]) + + rw = np.concatenate([wall_b, wall_t]) + return rw + + def _init_walls_3d(self, dx, n_walls): + sp = self.special + + # thickness of wall particles + dxn = dx * n_walls + + # horizontal and vertical blocks + horiz = pos_init_cartesian_3d(np.array([sp.L, dxn, 0.4]), dx) + + # wall: bottom, top + wall_b = horiz.copy() + wall_t = horiz.copy() + np.array([0.0, sp.H + dxn, 0.0]) + + rw = np.concatenate([wall_b, wall_t]) + return rw + + def _init_pos2D(self, box_size, dx, n_walls): + sp = self.special + + # initialize fluid phase + r_f = np.array([0.0, 1.0]) * n_walls * dx + pos_init_cartesian_2d( + np.array([sp.L, sp.H]), dx + ) + + # initialize walls + r_w = self._init_walls_2d(dx, n_walls) + + # set tags + tag_f = jnp.full(len(r_f), Tag.FLUID, dtype=int) + tag_w = jnp.full(len(r_w), Tag.SOLID_WALL, dtype=int) + + r = np.concatenate([r_w, r_f]) + tag = np.concatenate([tag_w, tag_f]) + + # set velocity wall tag + box_size = self._box_size2D(n_walls) + mask_lid = r[:, 1] > (box_size[1] - n_walls * self.case.dx) + tag = jnp.where(mask_lid, Tag.MOVING_WALL, tag) + return r, tag + + def _init_pos3D(self, box_size, dx, n_walls): + sp = self.special + + # initialize fluid phase + r_f = np.array([0.0, 1.0, 0.0]) * n_walls * dx + pos_init_cartesian_3d( + np.array([sp.L, sp.H, 0.4]), dx + ) + + # initialize walls + r_w = self._init_walls_3d(dx, n_walls) + + # set tags + tag_f = jnp.full(len(r_f), Tag.FLUID, dtype=int) + tag_w = jnp.full(len(r_w), Tag.SOLID_WALL, dtype=int) + + r = np.concatenate([r_w, r_f]) + tag = np.concatenate([tag_w, tag_f]) + + # set velocity wall tag + box_size = self._box_size3D(n_walls) + mask_lid = r[:, 1] > (box_size[1] - n_walls * self.case.dx) + tag = jnp.where(mask_lid, Tag.MOVING_WALL, tag) + return r, tag + + def _offset_vec(self): + dim = self.cfg.case.dim + if dim == 2: + res = np.array([0.0, 1.0]) * self.cfg.solver.n_walls * self.cfg.case.dx + elif dim == 3: + res = np.array([0.0, 1.0, 0.0]) * self.cfg.solver.n_walls * self.cfg.case.dx + return res + + def _init_velocity2D(self, r): + return jnp.zeros_like(r) + + def _init_velocity3D(self, r): + return jnp.zeros_like(r) + + def _external_acceleration_fn(self, r): + return jnp.zeros_like(r) + + def _boundary_conditions_fn(self, state): + mask1 = state["tag"][:, None] == Tag.SOLID_WALL + mask2 = state["tag"][:, None] == Tag.MOVING_WALL + + state["u"] = jnp.where(mask1, 0.0, state["u"]) + state["v"] = jnp.where(mask1, 0.0, state["v"]) + state["u"] = jnp.where(mask2, self.u_wall, state["u"]) + state["v"] = jnp.where(mask2, self.u_wall, state["v"]) + + state["dudt"] = jnp.where(mask1, 0.0, state["dudt"]) + state["dvdt"] = jnp.where(mask1, 0.0, state["dvdt"]) + state["dudt"] = jnp.where(mask2, 0.0, state["dudt"]) + state["dvdt"] = jnp.where(mask2, 0.0, state["dvdt"]) + + return state diff --git a/cases/cf.yaml b/cases/cf.yaml new file mode 100644 index 0000000..0ae8f0d --- /dev/null +++ b/cases/cf.yaml @@ -0,0 +1,24 @@ +extends: JAX_SPH_DEFAULTS + +seed: 123 + +case: + source: "cf.py" + dim: 2 + dx: 0.0166666 + viscosity: 100.0 + u_ref: 1.25 + special: + L: 0.4 # water column length + H: 1.0 # water column height + u_x_wall: 1.25 + +solver: + dt: 0.0000005 + t_end: 0.01 + is_bc_trick: True + +io: + write_type: ["h5"] + write_every: 200 + data_path: "data/debug" \ No newline at end of file diff --git a/tests/test_cf2d.py b/tests/test_cf2d.py new file mode 100644 index 0000000..4875c34 --- /dev/null +++ b/tests/test_cf2d.py @@ -0,0 +1,119 @@ +"""Test a full run of the solver on the Coette flow case from the validations.""" + +import os + +import jax.numpy as jnp +import numpy as np +import pytest +from jax import config +from omegaconf import OmegaConf + +from main import load_embedded_configs + + +def u_series_cf_exp(y, t, n_max=10): + """Analytical solution to unsteady Couette flow (low Re) + + Based on Series expansion as shown in: + "Modeling Low Reynolds Number Incompressible Flows Using SPH" + ba Morris et al. 1997 + """ + + eta = 100.0 # dynamic viscosity + rho = 1.0 # denstiy + nu = eta / rho # kinematic viscosity + u_max = 1.25 # max velocity in middle of channel + d = 1.0 # channel width + + Re = u_max * d / nu + print(f"Couette flow at Re={Re}") + + offset = u_max * y / d + + def term(n): + base = np.pi * n / d + + prefactor = 2 * u_max / (n * np.pi) * (-1) ** n + sin_term = np.sin(base * y) + exp_term = np.exp(-(base**2) * nu * t) + return prefactor * sin_term * exp_term + + res = offset + for i in range(1, n_max): + res += term(i) + + return res + + +@pytest.fixture +def setup_simulation(): + y_axis = np.linspace(0, 1, 21) + t_dimless = [0.0005, 0.001, 0.005] + # get analytical solution + ref_solutions = [] + for t_val in t_dimless: + ref_solutions.append(u_series_cf_exp(y_axis, t_val)) + return y_axis, t_dimless, ref_solutions + + +def run_simulation(tmp_path, tvf, solver): + """Emulate `main.py`.""" + data_path = tmp_path / f"cf_test_{tvf}" + + cli_args = OmegaConf.create( + { + "config": "cases/cf.yaml", + "case": {"dx": 0.0333333}, + "solver": {"name": solver, "tvf": tvf, "dt": 0.000002, "t_end": 0.005}, + "io": {"write_every": 250, "data_path": str(data_path)}, + } + ) + cfg = load_embedded_configs(cli_args) + + # Specify cuda device. These setting must be done before importing jax-md. + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 from TensorFlow + os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu) + os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(cfg.xla_mem_fraction) + + if cfg.dtype == "float64": + config.update("jax_enable_x64", True) + + from jax_sph.simulate import simulate + + simulate(cfg) + + return data_path + + +def get_solution(data_path, t_dimless, y_axis): + from jax_sph.utils import sph_interpolator + + dir = os.listdir(data_path)[0] + cfg = OmegaConf.load(data_path / dir / "config.yaml") + step_max = np.array(np.rint(cfg.solver.t_end / cfg.solver.dt), dtype=int) + digits = len(str(step_max)) + + y_axis += 3 * cfg.case.dx + rs = 0.2 * jnp.ones([y_axis.shape[0], 2]) + rs = rs.at[:, 1].set(y_axis) + solutions = [] + for i in range(len(t_dimless)): + file_name = ( + "traj_" + str(int(t_dimless[i] / cfg.solver.dt)).zfill(digits) + ".h5" + ) + src_path = data_path / dir / file_name + interp_vel_fn = sph_interpolator(cfg, src_path) + solutions.append(interp_vel_fn(src_path, rs, prop="u", dim_ind=0)) + return solutions + + +@pytest.mark.parametrize( + "tvf, solver", [(0.0, "SPH"), (1.0, "SPH"), (0.0, "RIE"), (0.0, "DELTA")] +) +def test_cf2d(tvf, solver, tmp_path, setup_simulation): + """Test whether the couette flow simulation matches the analytical solution""" + y_axis, t_dimless, ref_solutions = setup_simulation + data_path = run_simulation(tmp_path, tvf, solver) + solutions = get_solution(data_path, t_dimless, y_axis) + for sol, ref_sol in zip(solutions, ref_solutions): + assert np.allclose(sol, ref_sol, atol=1e-2), "Velocity profile does not match." diff --git a/validation/cf2d.sh b/validation/cf2d.sh new file mode 100644 index 0000000..7104b9d --- /dev/null +++ b/validation/cf2d.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# Validation of the 2D Couette Flow +# Reference result from: +# "Modeling Low Reynolds Number Incompressible Flows Using SPH", Morris 1997 + +# Generate data +python main.py config=cases/cf.yaml solver.tvf=1.0 io.data_path=data_valid/cf2d_tvf/ +python main.py config=cases/cf.yaml solver.tvf=0.0 io.data_path=data_valid/cf2d_notvf/ +python main.py config=cases/cf.yaml solver.tvf=0.0 solver.name=RIE solver.density_evolution=True io.data_path=data_valid/cf2d_Rie/ + +# Run validation script +python validation/validate.py --case=2D_CF --src_dir=data_valid/cf2d_tvf/ +python validation/validate.py --case=2D_CF --src_dir=data_valid/cf2d_notvf/ +python validation/validate.py --case=2D_CF --src_dir=data_valid/cf2d_Rie/ \ No newline at end of file diff --git a/validation/validate.py b/validation/validate.py index 1d6dd6d..3861dda 100644 --- a/validation/validate.py +++ b/validation/validate.py @@ -527,7 +527,7 @@ def val_2D_PF( ], save_fig=False, ): - def u_series_exp(y, t, n_max=10): + def u_series_pf_exp(y, t, n_max=10): """Analytical solution to unsteady Poiseuille flow (low Re) Based on Series expansion as shown in: @@ -573,7 +573,7 @@ def term(n): t_dimless = [0.0002, 0.001, 0.002, 0.01] for t_val, t_label in zip(t_dimless, t_axis): - plt.plot(y_axis, u_series_exp(y_axis, t_val), label=f"t={t_label}") + plt.plot(y_axis, u_series_pf_exp(y_axis, t_val), label=f"t={t_label}") # extract points from our solution dirs = os.listdir(val_dir_path) @@ -629,6 +629,115 @@ def term(n): plt.show() +def val_2D_CF( + val_dir_path, + dim=2, + nxs=[ + 60, + ], + save_fig=False, +): + def u_series_cf_exp(y, t, n_max=10): + """Analytical solution to unsteady Couette flow (low Re) + + Based on Series expansion as shown in: + "Modeling Low Reynolds Number Incompressible Flows Using SPH" + ba Morris et al. 1997 + """ + + eta = 100.0 # dynamic viscosity + rho = 1.0 # denstiy + nu = eta / rho # kinematic viscosity + u_max = 1.25 # max velocity in middle of channel + d = 1.0 # channel width + + Re = u_max * d / nu + print(f"Couette flow at Re={Re}") + + offset = u_max * y / d + + def term(n): + base = np.pi * n / d + + prefactor = 2 * u_max / (n * np.pi) * (-1) ** n + sin_term = np.sin(base * y) + exp_term = np.exp(-(base**2) * nu * t) + return prefactor * sin_term * exp_term + + res = offset + for i in range(1, n_max): + res += term(i) + + return res + + # analytical solution + + y_axis = np.linspace(0, 1, 100) + t_axis = [ + r"$0.02\times 10^{-2}$", + r"$0.10\times 10^{-2}$", + r"$0.20\times 10^{-2}$", + r"$1.00\times 10^{-2}$", + ] + t_dimless = [0.0002, 0.001, 0.002, 0.01] + + for t_val, t_label in zip(t_dimless, t_axis): + plt.plot(y_axis, u_series_cf_exp(y_axis, t_val), label=f"t={t_label}") + + # extract points from our solution + dirs = os.listdir(val_dir_path) + dirs = [d for d in dirs if os.path.isdir(os.path.join(val_dir_path, d))] + assert len(dirs) == 1, f"Expected only one directory in {val_dir_path}" + cfg = OmegaConf.load(os.path.join(val_dir_path, dirs[0], "config.yaml")) + dx, dt = cfg.case.dx, cfg.solver.dt + assert dt == 0.0000005 + assert dx == 0.0166666 + + num_points = 21 + dx_plot = 0.05 + y_axis = jnp.array([dx_plot * i for i in range(num_points)]) + 3 * dx + rs = 0.2 * jnp.ones([y_axis.shape[0], 2]) + rs = rs.at[:, 1].set(y_axis) + + step_max = np.array(np.rint(cfg.solver.t_end / dt), dtype=int) + digits = len(str(step_max)) + + for i, t_val in enumerate(t_dimless): + step = np.array(np.rint(t_val / dt), dtype=int) + file_name = "traj_" + str(step).zfill(digits) + ".h5" + src_path = os.path.join(val_dir_path, dirs[0], file_name) + + if i == 0: + interp_vel_fn = sph_interpolator(cfg, src_path) + + u_val = interp_vel_fn(src_path, rs, prop="u", dim_ind=0) + + if i == 0: + plt.plot(y_axis - 3 * dx, u_val, "ko", mfc="none", label=r"SPH, $r_c$=0.05") + else: + plt.plot(y_axis - 3 * dx, u_val, "ko", mfc="none") + + # plot layout + + plt.legend() + plt.ylim([0, 1.4]) + plt.xlim([0, 1]) + plt.xlabel(r"y [-]") + plt.ylabel(r"$u_x$ [-]") + # plt.title(f"{str(dim)}D Poiseuille Flow") + plt.grid() + plt.tight_layout() + + ###### save or visualize + + if save_fig: + os.makedirs(val_dir_path, exist_ok=True) + nxs_str = "_".join([str(i) for i in nxs]) + plt.savefig(f"{val_dir_path}/{str(dim)}D_CF_{nxs_str}_new.png") + + plt.show() + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--case", type=str, help="One of the above cases") @@ -641,6 +750,9 @@ def term(n): if args.case == "2D_PF": val_2D_PF(args.src_dir, 2, [60], True) + elif args.case == "2D_CF": + val_2D_CF(args.src_dir, 2, [60], True) + elif args.case == "2D_LDC": val_2D_LDC( args.src_dir_tvf, args.src_dir_notvf, args.src_dir_Rie, save_fig=True