|
11 | 11 | import time
|
12 | 12 | import uuid
|
13 | 13 | import zipfile
|
| 14 | +from importlib.metadata import Distribution, distributions |
14 | 15 | from pathlib import Path
|
15 | 16 | from typing import Any, Union
|
16 | 17 |
|
|
25 | 26 | ZPath = Union[Path, zipfile.Path]
|
26 | 27 | TEMP_FOLDER = Path(__file__).parent.parent / "temp"
|
27 | 28 | COMFY_PACK_DIR = Path(__file__).parent.parent / "src" / "comfy_pack"
|
28 |
| -EXCLUDE_PACKAGES = ["bentoml", "onnxruntime"] # TODO: standardize this |
| 29 | +EXCLUDE_PACKAGES = ["bentoml", "onnxruntime", "conda"] # TODO: standardize this |
| 30 | + |
| 31 | + |
| 32 | +def _get_requirement_string(dist: Distribution) -> str: |
| 33 | + direct_url_text = dist.read_text("direct_url.json") |
| 34 | + pinned_str = f'{dist.metadata["Name"]}=={dist.version}' |
| 35 | + if not direct_url_text: |
| 36 | + return pinned_str |
| 37 | + direct_url = json.loads(direct_url_text) |
| 38 | + if url := direct_url.get("url"): |
| 39 | + if url.startswith("file://"): |
| 40 | + # we are not able to share local files |
| 41 | + return pinned_str |
| 42 | + if vcs_info := direct_url.get("vcs_info"): |
| 43 | + url = f"{vcs_info['vcs']}+{url}@{vcs_info['commit_id']}" |
| 44 | + if subdirectory := direct_url.get("subdirectory"): |
| 45 | + url += f"#subdirectory={subdirectory}" |
| 46 | + return f"{dist.metadata['Name']} @ {url}" |
| 47 | + else: |
| 48 | + return pinned_str |
29 | 49 |
|
30 | 50 |
|
31 | 51 | async def _write_requirements(path: ZPath, extras: list[str] | None = None) -> None:
|
32 | 52 | print("Package => Writing requirements.txt")
|
33 | 53 | with path.joinpath("requirements.txt").open("w") as f:
|
34 |
| - proc = await asyncio.subprocess.create_subprocess_exec( |
35 |
| - sys.executable, |
36 |
| - "-m", |
37 |
| - "pip", |
38 |
| - "list", |
39 |
| - "--format", |
40 |
| - "freeze", |
41 |
| - "--exclude-editable", |
42 |
| - *[f"--exclude={p}" for p in EXCLUDE_PACKAGES], |
43 |
| - stdout=subprocess.PIPE, |
44 |
| - ) |
45 |
| - stdout, _ = await proc.communicate() |
46 |
| - f.write(stdout.decode().rstrip("\n") + "\n") |
| 54 | + for dist in distributions(): |
| 55 | + f.write(_get_requirement_string(dist) + "\n") |
47 | 56 | if extras:
|
48 | 57 | f.write("\n".join(extras) + "\n")
|
49 | 58 |
|
|
0 commit comments