Commit 7f9c58a
Co-authored-by: davidfindlay1 <[email protected]>
1 parent e3875a5
Showing 16 changed files with 790 additions and 278 deletions.
@@ -1 +1 @@
-__version__ = "0.1.0"
+__version__ = "1.0.0"
@@ -0,0 +1,146 @@
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

import torch


def split_einsum(q, k, v, mask, heads, dim_head):
    """ Attention Implementation backing AttentionImplementations.SPLIT_EINSUM
    - Implements https://machinelearning.apple.com/research/neural-engine-transformers
    - Recommended for ANE
    - Marginally slower on GPU
    """
    mh_q = [
        q[:, head_idx * dim_head:(head_idx + 1) * dim_head, :, :]
        for head_idx in range(heads)
    ]  # (bs, dim_head, 1, max_seq_length) * heads

    k = k.transpose(1, 3)
    mh_k = [
        k[:, :, :, head_idx * dim_head:(head_idx + 1) * dim_head]
        for head_idx in range(heads)
    ]  # (bs, max_seq_length, 1, dim_head) * heads

    mh_v = [
        v[:, head_idx * dim_head:(head_idx + 1) * dim_head, :, :]
        for head_idx in range(heads)
    ]  # (bs, dim_head, 1, max_seq_length) * heads

    attn_weights = [
        torch.einsum("bchq,bkhc->bkhq", [qi, ki]) * (dim_head**-0.5)
        for qi, ki in zip(mh_q, mh_k)
    ]  # (bs, max_seq_length, 1, max_seq_length) * heads

    if mask is not None:
        for head_idx in range(heads):
            attn_weights[head_idx] = attn_weights[head_idx] + mask

    attn_weights = [
        aw.softmax(dim=1) for aw in attn_weights
    ]  # (bs, max_seq_length, 1, max_seq_length) * heads
    attn = [
        torch.einsum("bkhq,bchk->bchq", wi, vi)
        for wi, vi in zip(attn_weights, mh_v)
    ]  # (bs, dim_head, 1, max_seq_length) * heads

    attn = torch.cat(attn, dim=1)  # (bs, dim, 1, max_seq_length)
    return attn
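
# A minimal usage sketch with hypothetical sizes: q, k, v arrive in the
# ANE-friendly BC1S layout, i.e. (bs, heads * dim_head, 1, seq_len).
#
#   q = k = v = torch.randn(1, 8 * 64, 1, 77)
#   out = split_einsum(q, k, v, mask=None, heads=8, dim_head=64)
#   assert out.shape == (1, 512, 1, 77)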


CHUNK_SIZE = 512


def split_einsum_v2(q, k, v, mask, heads, dim_head):
    """ Attention Implementation backing AttentionImplementations.SPLIT_EINSUM_V2
    - Implements https://machinelearning.apple.com/research/neural-engine-transformers
    - Recommended for ANE
    - Marginally slower on GPU
    - Chunks the query sequence to avoid large intermediate tensors and improve ANE performance
    """
    query_seq_length = q.size(3)
    num_chunks = query_seq_length // CHUNK_SIZE

    if num_chunks == 0:
        logger.info(
            "AttentionImplementations.SPLIT_EINSUM_V2: query sequence too short to chunk "
            f"({query_seq_length}<{CHUNK_SIZE}), fall back to AttentionImplementations.SPLIT_EINSUM (safe to ignore)")
        return split_einsum(q, k, v, mask, heads, dim_head)

    logger.info(
        "AttentionImplementations.SPLIT_EINSUM_V2: Splitting query sequence length of "
        f"{query_seq_length} into {num_chunks} chunks")

    mh_q = [
        q[:, head_idx * dim_head:(head_idx + 1) * dim_head, :, :]
        for head_idx in range(heads)
    ]  # (bs, dim_head, 1, max_seq_length) * heads

    # Chunk the query sequence for each head
    mh_q_chunked = [
        [h_q[..., chunk_idx * CHUNK_SIZE:(chunk_idx + 1) * CHUNK_SIZE]
         for chunk_idx in range(num_chunks)]
        for h_q in mh_q
    ]  # ((bs, dim_head, 1, QUERY_SEQ_CHUNK_SIZE) * num_chunks) * heads

    k = k.transpose(1, 3)
    mh_k = [
        k[:, :, :, head_idx * dim_head:(head_idx + 1) * dim_head]
        for head_idx in range(heads)
    ]  # (bs, max_seq_length, 1, dim_head) * heads

    mh_v = [
        v[:, head_idx * dim_head:(head_idx + 1) * dim_head, :, :]
        for head_idx in range(heads)
    ]  # (bs, dim_head, 1, max_seq_length) * heads

    attn_weights = [
        [
            torch.einsum("bchq,bkhc->bkhq", [qi_chunk, ki]) * (dim_head**-0.5)
            for qi_chunk in h_q_chunked
        ] for h_q_chunked, ki in zip(mh_q_chunked, mh_k)
    ]  # ((bs, max_seq_length, 1, chunk_size) * num_chunks) * heads

    # Note: unlike split_einsum, this chunked path does not apply `mask` to the
    # attention weights
    attn_weights = [
        [aw_chunk.softmax(dim=1) for aw_chunk in aw_chunked]
        for aw_chunked in attn_weights
    ]  # ((bs, max_seq_length, 1, chunk_size) * num_chunks) * heads

    attn = [
        [
            torch.einsum("bkhq,bchk->bchq", wi_chunk, vi)
            for wi_chunk in wi_chunked
        ] for wi_chunked, vi in zip(attn_weights, mh_v)
    ]  # ((bs, dim_head, 1, chunk_size) * num_chunks) * heads

    attn = torch.cat([
        torch.cat(attn_chunked, dim=3) for attn_chunked in attn
    ], dim=1)  # (bs, dim, 1, max_seq_length)

    return attn
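
# Shape walk-through with hypothetical sizes: for q of shape (1, 8 * 64, 1, 4096)
# and CHUNK_SIZE = 512, num_chunks = 8, so each inner einsum materializes a
# (1, 4096, 1, 512) attention-weight chunk instead of a single (1, 4096, 1, 4096)
# tensor, which is what keeps the intermediates small enough to suit the ANE.
#
#   q = k = v = torch.randn(1, 8 * 64, 1, 4096)
#   out = split_einsum_v2(q, k, v, mask=None, heads=8, dim_head=64)
#   assert out.shape == (1, 512, 1, 4096)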


def original(q, k, v, mask, heads, dim_head):
    """ Attention Implementation backing AttentionImplementations.ORIGINAL
    - Not recommended for ANE
    - Recommended for GPU
    """
    bs = q.size(0)
    mh_q = q.view(bs, heads, dim_head, -1)
    mh_k = k.view(bs, heads, dim_head, -1)
    mh_v = v.view(bs, heads, dim_head, -1)

    attn_weights = torch.einsum("bhcq,bhck->bhqk", [mh_q, mh_k])
    attn_weights.mul_(dim_head**-0.5)

    if mask is not None:
        attn_weights = attn_weights + mask

    attn_weights = attn_weights.softmax(dim=3)

    attn = torch.einsum("bhqk,bhck->bhcq", [attn_weights, mh_v])
    attn = attn.contiguous().view(bs, heads * dim_head, 1, -1)
    return attn
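
All three functions compute the same multi-head attention and should agree numerically when no mask is passed (the chunked path above ignores it); a quick self-check sketch, with the module name and sizes assumed:

    import torch
    from attention import original, split_einsum, split_einsum_v2  # module name assumed

    q = k = v = torch.randn(2, 8 * 64, 1, 1024)  # BC1S: (bs, heads * dim_head, 1, seq_len)
    ref = original(q, k, v, None, heads=8, dim_head=64)
    for impl in (split_einsum, split_einsum_v2):  # 1024 // CHUNK_SIZE = 2 chunks for v2
        assert torch.allclose(impl(q, k, v, None, heads=8, dim_head=64), ref, atol=1e-4)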
@@ -0,0 +1,60 @@
from python_coreml_stable_diffusion.torch2coreml import _compile_coreml_model

import argparse
import coremltools as ct
import numpy as np
import os
import torch
import torch.nn as nn

# TODO: Read these values off of the NLContextualEmbedding API to enforce dimensions and track API versioning
MAX_SEQUENCE_LENGTH = 256
EMBED_DIM = 512
BATCH_SIZE = 1


def main(args):
    # Layer that was trained to map NLContextualEmbedding to your text_encoder.hidden_size dimensionality
    text_encoder_projection = torch.jit.load(args.input_path)

    # Prepare random inputs for tracing the network before conversion
    random_input = torch.randn(BATCH_SIZE, MAX_SEQUENCE_LENGTH, EMBED_DIM)

    # Create a class to bake in the reshape operations required to fit the existing model interface
    class TextEncoderProjection(nn.Module):
        def __init__(self, proj):
            super().__init__()
            self.proj = proj

        def forward(self, x):
            return self.proj(x).transpose(1, 2).unsqueeze(2)  # BSC -> BC1S

    # Trace the torch model
    text_encoder_projection = torch.jit.trace(
        TextEncoderProjection(text_encoder_projection), (random_input,))

    # Convert the model to Core ML and save the .mlpackage
    mlpackage_path = os.path.join(args.output_dir, "MultilingualTextEncoderProjection.mlpackage")
    ct.convert(
        text_encoder_projection,
        inputs=[ct.TensorType('nlcontextualembeddings_output', shape=(1, MAX_SEQUENCE_LENGTH, EMBED_DIM), dtype=np.float32)],
        outputs=[ct.TensorType('encoder_hidden_states', dtype=np.float32)],
        minimum_deployment_target=ct.target.macOS14,  # NLContextualEmbedding minimum availability build
        convert_to='mlprogram',
    ).save(mlpackage_path)

    # Compile the model and save it under the specified directory
    _compile_coreml_model(mlpackage_path, args.output_dir, final_name="MultilingualTextEncoderProjection")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-path",
        help="Path to the torchscript file that contains the projection layer"
    )
    parser.add_argument(
        "--output-dir",
        help="Output directory in which the Core ML model should be saved",
    )
    args = parser.parse_args()

    main(args)
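
A hypothetical invocation, assuming this script is saved as multilingual_projection.py and the trained projection layer was exported with torch.jit.save:

    python multilingual_projection.py --input-path TextEncoderProjection.pt --output-dir ./models

The script expects the projection to accept a (1, 256, 512) float32 NLContextualEmbedding output and writes MultilingualTextEncoderProjection.mlpackage plus its compiled counterpart under --output-dir.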