Skip to content

Commit 437e391

Browse files
committed
adds interface for building einsums
1 parent 6f81197 commit 437e391

File tree

13 files changed

+581
-36
lines changed

13 files changed

+581
-36
lines changed

.coveragerc

-28
This file was deleted.

.github/workflows/ci.yml

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
name: CI
2+
on:
3+
push:
4+
branches:
5+
- main
6+
pull_request:
7+
schedule:
8+
- cron: '17 3 * * 0'
9+
10+
jobs:
11+
flake8:
12+
name: Flake8
13+
runs-on: ubuntu-latest
14+
steps:
15+
- uses: actions/checkout@v2
16+
-
17+
uses: actions/setup-python@v1
18+
with:
19+
# matches compat target in setup.py
20+
python-version: '3.8'
21+
- name: "Main Script"
22+
run: |
23+
curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/main/prepare-and-run-flake8.sh
24+
. ./prepare-and-run-flake8.sh "$(basename $GITHUB_REPOSITORY)" test examples
25+
26+
pylint:
27+
name: Pylint
28+
runs-on: ubuntu-latest
29+
steps:
30+
- uses: actions/checkout@v2
31+
-
32+
uses: actions/setup-python@v1
33+
with:
34+
python-version: '3.x'
35+
- name: "Main Script"
36+
run: |
37+
sed -i "s/loopy.git/loopy.git@kernel_callables_v3-edit2/g" requirements.txt
38+
curl -L -O -k https://tiker.net/ci-support-v0
39+
. ci-support-v0
40+
build_py_project_in_conda_env
41+
run_pylint "$(basename $GITHUB_REPOSITORY)" test/test_*.py examples
42+
43+
mypy:
44+
name: Mypy
45+
runs-on: ubuntu-latest
46+
steps:
47+
- uses: actions/checkout@v2
48+
-
49+
uses: actions/setup-python@v1
50+
with:
51+
python-version: '3.x'
52+
- name: "Main Script"
53+
run: |
54+
curl -L -O https://tiker.net/ci-support-v0
55+
. ./ci-support-v0
56+
build_py_project_in_conda_env
57+
python -m pip install mypy
58+
./run-mypy.sh
59+
60+
pytest:
61+
name: Conda Pytest
62+
runs-on: ubuntu-latest
63+
steps:
64+
- uses: actions/checkout@v2
65+
- name: "Main Script"
66+
run: |
67+
curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project-within-miniconda.sh
68+
. ./build-and-test-py-project-within-miniconda.sh
69+
70+
examples:
71+
name: Conda Examples
72+
runs-on: ubuntu-latest
73+
steps:
74+
- uses: actions/checkout@v2
75+
- name: "Main Script"
76+
run: |
77+
curl -L -O -k https://tiker.net/ci-support-v0
78+
. ci-support-v0
79+
build_py_project_in_conda_env
80+
run_examples
81+
82+
docs:
83+
name: Documentation
84+
runs-on: ubuntu-latest
85+
steps:
86+
- uses: actions/checkout@v2
87+
-
88+
uses: actions/setup-python@v1
89+
with:
90+
python-version: '3.x'
91+
- name: "Main Script"
92+
run: |
93+
./.ci-support/fix-code-for-docs.sh
94+
curl -L -O -k https://tiker.net/ci-support-v0
95+
. ci-support-v0
96+
build_py_project_in_conda_env
97+
build_docs

.run-pylint.py

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#!/usr/bin/env python3
2+
"""This script allows using Pylint with YAML-based config files.
3+
4+
The usage of this script is identical to Pylint, except that this script accepts
5+
an additional argument "--yaml-rcfile=", which specifies a path to a YAML file
6+
from which command line options are derived. The "--yaml-rcfile=" argument may
7+
be given multiple times.
8+
9+
The YAML config file format is a list of "arg"/"val" entries. Multiple or
10+
omitted values are allowed. Repeated arguments are allowed. An example is as
11+
follows:
12+
13+
---
14+
- arg: errors-only
15+
- arg: ignore
16+
val:
17+
- dir1
18+
- dir2
19+
- arg: ignore
20+
val: dir3
21+
22+
This example is equivalent to invoking pylint with the options
23+
24+
pylint --errors-only --ignore=dir1,dir2 --ignore=dir3
25+
26+
"""
27+
28+
import sys
29+
import logging
30+
import shlex
31+
32+
import pylint.lint
33+
import yaml
34+
35+
logger = logging.getLogger(__name__)
36+
37+
38+
def generate_args_from_yaml(input_yaml):
39+
"""Generate a list of strings suitable for use as Pylint args, from YAML.
40+
41+
Arguments:
42+
input_yaml: YAML data, as an input file or bytes
43+
44+
"""
45+
46+
parsed_data = yaml.safe_load(input_yaml)
47+
48+
for entry in parsed_data:
49+
arg = entry["arg"]
50+
val = entry.get("val")
51+
52+
if val is not None:
53+
if isinstance(val, list):
54+
val = ",".join(str(item) for item in val)
55+
56+
yield "--%s=%s" % (arg, val)
57+
else:
58+
yield "--%s" % arg
59+
60+
61+
YAML_RCFILE_PREFIX = "--yaml-rcfile="
62+
63+
64+
def main():
65+
"""Process command line args and run Pylint."""
66+
args = []
67+
68+
for arg in sys.argv[1:]:
69+
if arg.startswith(YAML_RCFILE_PREFIX):
70+
config_path = arg[len(YAML_RCFILE_PREFIX):]
71+
with open(config_path, "r") as config_file:
72+
args.extend(generate_args_from_yaml(config_file))
73+
else:
74+
args.append(arg)
75+
76+
logger.info(" ".join(shlex.quote(arg) for arg in ["pylint"] + args))
77+
pylint.lint.Run(args)
78+
79+
80+
if __name__ == "__main__":
81+
logging.basicConfig(level=logging.INFO)
82+
main()

examples/build_einsum.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import feinsum as f
2+
3+
4+
print(f.einsum("ij,j->i",
5+
f.array((10, 4), "float32"),
6+
f.array((4, ), "float32"),
7+
))

run-mypy.sh

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#! /bin/bash
2+
3+
mypy --show-error-codes --strict src/feinsum/

setup.cfg

+24-5
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ license_files = LICENSE.txt
1313
long_description = file: README.rst
1414
long_description_content_type = text/x-rst; charset=UTF-8
1515
url = https://github.com/kaushikcfd/feinsum/
16-
version = attr: feinsum.version.VERSION
16+
version = attr: VERSION.VERSION
1717
# Add here related links, for example:
1818
# project_urls =
1919
# Documentation = https://pyscaffold.org/
@@ -52,14 +52,15 @@ python_requires = >=3.8
5252
# Version specifiers like >=2.2,<3.0 avoid problems due to API changes in
5353
# new major versions. This works if the required packages follow Semantic Versioning.
5454
# For more information, check out https://semver.org/.
55-
# install_requires =
56-
# importlib-metadata; python_version<"3.8"
55+
install_requires = importlib-metadata; python_version<"3.8"
56+
numpy>=1.20
57+
5758

5859

5960
[options.packages.find]
6061
where = src
6162
exclude =
62-
tests
63+
test
6364

6465
[options.extras_require]
6566
# Add here additional requirements for extra features, to install with:
@@ -90,7 +91,7 @@ norecursedirs =
9091
dist
9192
build
9293
.tox
93-
testpaths = tests
94+
testpaths = test
9495
# Use pytest markers to select/deselect specific tests
9596
# markers =
9697
# slow: mark tests as slow (deselect with '-m "not slow"')
@@ -115,3 +116,21 @@ exclude =
115116
dist
116117
.eggs
117118
docs/conf.py
119+
120+
[mypy]
121+
plugins = numpy.typing.mypy_plugin
122+
123+
[mypy-islpy]
124+
ignore_missing_imports = True
125+
126+
[mypy-loopy.*]
127+
ignore_missing_imports = True
128+
129+
[mypy-numpy]
130+
ignore_missing_imports = True
131+
132+
[mypy-pymbolic.*]
133+
ignore_missing_imports = True
134+
135+
[mypy-pyopencl]
136+
ignore_missing_imports = True

src/feinsum/version.py renamed to src/VERSION.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,8 @@
99
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
1010
copies of the Software, and to permit persons to whom the Software is
1111
furnished to do so, subject to the following conditions:
12-
1312
The above copyright notice and this permission notice shall be included in
1413
all copies or substantial portions of the Software.
15-
1614
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
1715
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
1816
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
@@ -23,4 +21,4 @@
2321
"""
2422

2523

26-
VERSION = (2022, 0)
24+
VERSION = (2022, 1)

src/feinsum/__init__.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from feinsum.einsum import (Einsum,
2+
VeryLongAxis, EinsumAxisAccess,
3+
FreeAxis, SummationAxis)
4+
from feinsum.make_einsum import (einsum, Array, ArrayT, array)
5+
6+
7+
__all__ = (
8+
"Einsum", "VeryLongAxis", "EinsumAxisAccess", "FreeAxis",
9+
"SummationAxis",
10+
11+
"einsum", "Array", "ArrayT", "array",
12+
)

src/feinsum/einsum.py

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from __future__ import annotations
2+
3+
import abc
4+
import numpy as np
5+
6+
from typing import Union, Tuple, Any, FrozenSet
7+
from dataclasses import dataclass
8+
9+
10+
IntegralT = Union[int, np.int8, np.int16, np.int32, np.int64, np.uint8,
11+
np.uint16, np.uint32, np.uint64]
12+
INT_CLASSES = (int, np.int8, np.int16, np.int32, np.int64, np.uint8,
13+
np.uint16, np.uint32, np.uint64)
14+
15+
16+
ShapeComponentT = Union[IntegralT, "VeryLongAxis"]
17+
ShapeT = Tuple[ShapeComponentT, ...]
18+
19+
20+
class VeryLongAxis:
21+
"""
22+
Describes a shape which can be assumed to be very large.
23+
"""
24+
# TODO: Record the threshold over which an axis could be considered as
25+
# "VeryLong."
26+
27+
28+
@dataclass(frozen=True, repr=True, eq=True)
29+
class EinsumAxisAccess(abc.ABC):
30+
"""
31+
Base class for axis access types in an einsum expression.
32+
"""
33+
34+
35+
@dataclass(frozen=True, repr=True, eq=True)
36+
class FreeAxis(EinsumAxisAccess):
37+
"""
38+
Records the axis of an einsum argument over which contraction is not performed.
39+
40+
.. attribute:: output_index
41+
42+
Position of the corresponding index in the einsum's output.
43+
"""
44+
output_index: int
45+
46+
47+
@dataclass(frozen=True, repr=True, eq=True)
48+
class SummationAxis(EinsumAxisAccess):
49+
"""
50+
Records an index in an einsum expression over which reduction is performed.
51+
Sometimes also referred to as an axis with a corresponding "dummy index" in
52+
Ricci Calculus.
53+
54+
.. attribute:: index
55+
56+
An integer which is unique to a reduction index of an einsum.
57+
"""
58+
idx: int
59+
60+
61+
@dataclass(frozen=True, eq=True, repr=True)
62+
class Einsum:
63+
"""
64+
An einsum expression.
65+
"""
66+
arg_shapes: Tuple[ShapeT, ...]
67+
arg_dtypes: Tuple[np.dtype[Any], ...]
68+
access_descriptors: Tuple[Tuple[EinsumAxisAccess, ...], ...]
69+
use_matrix: Tuple[Tuple[FrozenSet[str], ...]]

src/feinsum/generator.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""
2+
Generates Loopy kernels on which transformations could be applied.
3+
"""
4+
5+
import loopy as lp
6+
7+
from feinsum.einsum import Einsum
8+
9+
10+
def generate_loopy_kernel(einsum_decr: Einsum) -> lp.TranslationUnit:
11+
...

0 commit comments

Comments
 (0)