WWDC23 update
Co-authored-by: davidfindlay1 <[email protected]>
atiorh and davidfindlay1 committed Jun 14, 2023
1 parent e3875a5 commit 7f9c58a
Showing 16 changed files with 790 additions and 278 deletions.
281 changes: 153 additions & 128 deletions README.md

Large diffs are not rendered by default.

Binary file added assets/float16_cpuandne_readmereel.png
Binary file added assets/float16_gpu_readmereel.png
Binary file added assets/palette6_cpuandne_readmereel.png
2 changes: 1 addition & 1 deletion python_coreml_stable_diffusion/_version.py
@@ -1 +1 @@
-__version__ = "0.1.0"
+__version__ = "1.0.0"
146 changes: 146 additions & 0 deletions python_coreml_stable_diffusion/attention.py
@@ -0,0 +1,146 @@
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

import torch

def split_einsum(q, k, v, mask, heads, dim_head):
""" Attention Implementation backing AttentionImplementations.SPLIT_EINSUM
- Implements https://machinelearning.apple.com/research/neural-engine-transformers
- Recommended for ANE
- Marginally slower on GPU
"""
mh_q = [
q[:, head_idx * dim_head:(head_idx + 1) *
dim_head, :, :] for head_idx in range(heads)
] # (bs, dim_head, 1, max_seq_length) * heads

k = k.transpose(1, 3)
mh_k = [
k[:, :, :,
head_idx * dim_head:(head_idx + 1) * dim_head]
for head_idx in range(heads)
] # (bs, max_seq_length, 1, dim_head) * heads

mh_v = [
v[:, head_idx * dim_head:(head_idx + 1) *
dim_head, :, :] for head_idx in range(heads)
] # (bs, dim_head, 1, max_seq_length) * heads

attn_weights = [
torch.einsum("bchq,bkhc->bkhq", [qi, ki]) * (dim_head**-0.5)
for qi, ki in zip(mh_q, mh_k)
] # (bs, max_seq_length, 1, max_seq_length) * heads

if mask is not None:
for head_idx in range(heads):
attn_weights[head_idx] = attn_weights[head_idx] + mask

attn_weights = [
aw.softmax(dim=1) for aw in attn_weights
] # (bs, max_seq_length, 1, max_seq_length) * heads
attn = [
torch.einsum("bkhq,bchk->bchq", wi, vi)
for wi, vi in zip(attn_weights, mh_v)
] # (bs, dim_head, 1, max_seq_length) * heads

attn = torch.cat(attn, dim=1) # (bs, dim, 1, max_seq_length)
return attn
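
# --- Editor's illustration (hypothetical, not part of the original file) ---
# A minimal usage sketch for split_einsum, assuming the (batch, channels, 1,
# seq_length) tensor layout used throughout this module and made-up SD-like
# dimensions:
#
#   bs, heads, dim_head, seq_len = 2, 8, 40, 77
#   q = torch.randn(bs, heads * dim_head, 1, seq_len)
#   k = torch.randn(bs, heads * dim_head, 1, seq_len)
#   v = torch.randn(bs, heads * dim_head, 1, seq_len)
#   out = split_einsum(q, k, v, mask=None, heads=heads, dim_head=dim_head)
#   # out.shape == (bs, heads * dim_head, 1, seq_len)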


CHUNK_SIZE = 512

def split_einsum_v2(q, k, v, mask, heads, dim_head):
""" Attention Implementation backing AttentionImplementations.SPLIT_EINSUM_V2
- Implements https://machinelearning.apple.com/research/neural-engine-transformers
- Recommended for ANE
- Marginally slower on GPU
    - Chunks the query sequence to avoid large intermediate tensors, improving ANE performance
"""
query_seq_length = q.size(3)
num_chunks = query_seq_length // CHUNK_SIZE
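    # Editor's note: the integer division assumes query_seq_length is a
    # multiple of CHUNK_SIZE once chunking kicks in; any remainder beyond
    # num_chunks * CHUNK_SIZE would be dropped by the chunked slices below.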

if num_chunks == 0:
logger.info(
"AttentionImplementations.SPLIT_EINSUM_V2: query sequence too short to chunk "
f"({query_seq_length}<{CHUNK_SIZE}), fall back to AttentionImplementations.SPLIT_EINSUM (safe to ignore)")
return split_einsum(q, k, v, mask, heads, dim_head)

logger.info(
"AttentionImplementations.SPLIT_EINSUM_V2: Splitting query sequence length of "
f"{query_seq_length} into {num_chunks} chunks")

mh_q = [
q[:, head_idx * dim_head:(head_idx + 1) *
dim_head, :, :] for head_idx in range(heads)
] # (bs, dim_head, 1, max_seq_length) * heads

# Chunk the query sequence for each head
mh_q_chunked = [
[h_q[..., chunk_idx * CHUNK_SIZE:(chunk_idx + 1) * CHUNK_SIZE] for chunk_idx in range(num_chunks)]
for h_q in mh_q
] # ((bs, dim_head, 1, QUERY_SEQ_CHUNK_SIZE) * num_chunks) * heads

k = k.transpose(1, 3)
mh_k = [
k[:, :, :,
head_idx * dim_head:(head_idx + 1) * dim_head]
for head_idx in range(heads)
] # (bs, max_seq_length, 1, dim_head) * heads

mh_v = [
v[:, head_idx * dim_head:(head_idx + 1) *
dim_head, :, :] for head_idx in range(heads)
] # (bs, dim_head, 1, max_seq_length) * heads

attn_weights = [
[
torch.einsum("bchq,bkhc->bkhq", [qi_chunk, ki]) * (dim_head**-0.5)
for qi_chunk in h_q_chunked
] for h_q_chunked, ki in zip(mh_q_chunked, mh_k)
] # ((bs, max_seq_length, 1, chunk_size) * num_chunks) * heads

attn_weights = [
[aw_chunk.softmax(dim=1) for aw_chunk in aw_chunked]
for aw_chunked in attn_weights
] # ((bs, max_seq_length, 1, chunk_size) * num_chunks) * heads

attn = [
[
torch.einsum("bkhq,bchk->bchq", wi_chunk, vi)
for wi_chunk in wi_chunked
] for wi_chunked, vi in zip(attn_weights, mh_v)
] # ((bs, dim_head, 1, chunk_size) * num_chunks) * heads

attn = torch.cat([
torch.cat(attn_chunked, dim=3) for attn_chunked in attn
], dim=1) # (bs, dim, 1, max_seq_length)

return attn
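
# Editor's note: as a worked example, a flattened 64x64 feature map gives
# query_seq_length = 4096, which is split into 4096 // 512 = 8 chunks of
# (bs, dim_head, 1, 512) per head; concatenating the chunked results back
# along the sequence axis reproduces split_einsum's output exactly.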


def original(q, k, v, mask, heads, dim_head):
""" Attention Implementation backing AttentionImplementations.ORIGINAL
- Not recommended for ANE
- Recommended for GPU
"""
bs = q.size(0)
mh_q = q.view(bs, heads, dim_head, -1)
mh_k = k.view(bs, heads, dim_head, -1)
mh_v = v.view(bs, heads, dim_head, -1)

attn_weights = torch.einsum("bhcq,bhck->bhqk", [mh_q, mh_k])
attn_weights.mul_(dim_head**-0.5)

if mask is not None:
attn_weights = attn_weights + mask

attn_weights = attn_weights.softmax(dim=3)

attn = torch.einsum("bhqk,bhck->bhcq", [attn_weights, mh_v])
attn = attn.contiguous().view(bs, heads * dim_head, 1, -1)
return attn
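
# --- Editor's illustration (hypothetical, not part of the original file) ---
# All three implementations compute the same attention on the same
# (batch, channels, 1, seq_length) inputs, so their outputs should agree to
# floating-point tolerance:
#
#   out_split = split_einsum(q, k, v, None, heads, dim_head)
#   out_orig = original(q, k, v, None, heads, dim_head)
#   assert torch.allclose(out_split, out_orig, atol=1e-5)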
60 changes: 60 additions & 0 deletions python_coreml_stable_diffusion/multilingual_projection.py
@@ -0,0 +1,60 @@
from python_coreml_stable_diffusion.torch2coreml import _compile_coreml_model

import argparse
import coremltools as ct
import numpy as np
import os
import torch
import torch.nn as nn

# TODO: Read these values off of the NLContextualEmbedding API to enforce dimensions and track API versioning
MAX_SEQUENCE_LENGTH = 256
EMBED_DIM = 512
BATCH_SIZE = 1

def main(args):
# Layer that was trained to map NLContextualEmbedding to your text_encoder.hidden_size dimensionality
text_encoder_projection = torch.jit.load(args.input_path)

# Prepare random inputs for tracing the network before conversion
random_input = torch.randn(BATCH_SIZE, MAX_SEQUENCE_LENGTH, EMBED_DIM)

# Create a class to bake in the reshape operations required to fit the existing model interface
class TextEncoderProjection(nn.Module):
def __init__(self, proj):
super().__init__()
self.proj = proj

def forward(self, x):
            return self.proj(x).transpose(1, 2).unsqueeze(2)  # (B, S, C) -> (B, C, 1, S)

# Trace the torch model
text_encoder_projection = torch.jit.trace(TextEncoderProjection(text_encoder_projection), (random_input,))

# Convert the model to Core ML
mlpackage_path = os.path.join(args.output_dir, "MultilingualTextEncoderProjection.mlpackage")
ct.convert(
text_encoder_projection,
inputs=[ct.TensorType('nlcontextualembeddings_output', shape=(1, MAX_SEQUENCE_LENGTH, EMBED_DIM), dtype=np.float32)],
outputs=[ct.TensorType('encoder_hidden_states', dtype=np.float32)],
minimum_deployment_target=ct.target.macOS14, # NLContextualEmbedding minimum availability build
convert_to='mlprogram',
    ).save(mlpackage_path)

# Compile the model and save it under the specified directory
_compile_coreml_model(mlpackage_path, args.output_dir, final_name="MultilingualTextEncoderProjection")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--input-path",
help="Path to the torchscript file that contains the projection layer"
)
parser.add_argument(
"--output-dir",
help="Output directory in which the Core ML model should be saved",
)
args = parser.parse_args()

main(args)
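
# Example invocation (hypothetical paths, editor's illustration):
#   python -m python_coreml_stable_diffusion.multilingual_projection \
#       --input-path /path/to/text_encoder_projection.pt \
#       --output-dir /path/to/output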