|
1 | 1 | import argparse
|
2 | 2 | import os
|
| 3 | +import re |
3 | 4 | import shutil
|
4 | 5 | import subprocess
|
5 | 6 | import sys
|
|
13 | 14 | python = sys.executable
|
14 | 15 | default_command_live = True
|
15 | 16 | index_url = os.environ.get('INDEX_URL', "")
|
| 17 | +re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*") |
16 | 18 |
|
17 | 19 | fooocus_name = 'Fooocus'
|
18 | 20 |
|
@@ -115,6 +117,43 @@ def run_pip(command, desc=None, live=default_command_live):
|
115 | 117 | return None
|
116 | 118 |
|
117 | 119 |
|
| 120 | + |
| 121 | +# This function was copied from [Fooocus](https://github.com/lllyasviel/Fooocus) repository. |
| 122 | +def requirements_met(requirements_file): |
| 123 | + """ |
| 124 | + Does a simple parse of a requirements.txt file to determine if all rerqirements in it |
| 125 | + are already installed. Returns True if so, False if not installed or parsing fails. |
| 126 | + """ |
| 127 | + |
| 128 | + import importlib.metadata |
| 129 | + import packaging.version |
| 130 | + |
| 131 | + with open(requirements_file, "r", encoding="utf8") as file: |
| 132 | + for line in file: |
| 133 | + if line.strip() == "": |
| 134 | + continue |
| 135 | + |
| 136 | + m = re.match(re_requirement, line) |
| 137 | + if m is None: |
| 138 | + return False |
| 139 | + |
| 140 | + package = m.group(1).strip() |
| 141 | + version_required = (m.group(2) or "").strip() |
| 142 | + |
| 143 | + if version_required == "": |
| 144 | + continue |
| 145 | + |
| 146 | + try: |
| 147 | + version_installed = importlib.metadata.version(package) |
| 148 | + except Exception: |
| 149 | + return False |
| 150 | + |
| 151 | + if packaging.version.parse(version_required) != packaging.version.parse(version_installed): |
| 152 | + return False |
| 153 | + |
| 154 | + return True |
| 155 | + |
| 156 | + |
118 | 157 | def download_repositories():
|
119 | 158 | import pygit2
|
120 | 159 |
|
@@ -175,17 +214,20 @@ def download_models():
|
175 | 214 | )
|
176 | 215 |
|
177 | 216 |
|
178 |
| -def run_pip_install(): |
179 |
| - print("Run pip install") |
180 |
| - run_pip("install -r requirements.txt", "requirements") |
181 |
| - run_pip("install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118", "torch") |
182 |
| - run_pip("install xformers", "xformers") |
183 |
| - |
184 |
| - |
185 | 217 | def prepare_environments(args) -> bool:
|
| 218 | + torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121") |
| 219 | + |
186 | 220 | # Check if need pip install
|
| 221 | + requirements_file = 'requirements.txt' |
| 222 | + if not requirements_met(requirements_file): |
| 223 | + run_pip(f"install -r \"{requirements_file}\"", "requirements") |
| 224 | + |
| 225 | + if not is_installed("torch") or not is_installed("torchvision"): |
| 226 | + print(f"torch_index_url: {torch_index_url}") |
| 227 | + run_pip(f"install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}", "torch") |
| 228 | + |
187 | 229 | if not is_installed('xformers'):
|
188 |
| - run_pip_install() |
| 230 | + run_pip("install xformers==0.0.21", "xformers") |
189 | 231 |
|
190 | 232 | skip_sync_repo = False
|
191 | 233 | if args.sync_repo is not None:
|
|
0 commit comments