// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.

import Foundation

/// A tokenizer based on byte pair encoding.
@available(iOS 16.2, macOS 13.1, *)
public struct BPETokenizer {
    /// A dictionary that maps pairs of tokens to the rank/order of the merge.
    let merges: [TokenPair: Int]

    /// A dictionary of tokens to identifiers.
    let vocabulary: [String: Int]

    /// The start token.
    let startToken: String = "<|startoftext|>"

    /// The end token.
    let endToken: String = "<|endoftext|>"

    /// The token used for padding.
    let padToken: String = "<|endoftext|>"

    /// The unknown token.
    let unknownToken: String = "<|endoftext|>"

    /// The identifier of the unknown token, falling back to 0 if it is absent from the vocabulary.
    var unknownTokenID: Int {
        vocabulary[unknownToken, default: 0]
    }

    /// Creates a tokenizer.
    ///
    /// - Parameters:
    ///   - merges: A dictionary that maps pairs of tokens to the rank/order of the merge.
    ///   - vocabulary: A dictionary of tokens to identifiers.
    public init(merges: [TokenPair: Int], vocabulary: [String: Int]) {
        self.merges = merges
        self.vocabulary = vocabulary
    }

    /// Creates a tokenizer by loading merges and vocabulary from URLs.
    ///
    /// - Parameters:
    ///   - mergesURL: The URL of a text file containing merges.
    ///   - vocabularyURL: The URL of a JSON file containing the vocabulary.
    public init(mergesAt mergesURL: URL, vocabularyAt vocabularyURL: URL) throws {
        self.merges = try Self.readMerges(url: mergesURL)
        // Propagate the error instead of crashing on a malformed vocabulary file.
        self.vocabulary = try Self.readVocabulary(url: vocabularyURL)
    }

    /// Tokenizes an input string.
    ///
    /// - Parameters:
    ///   - input: A string.
    ///   - minCount: The minimum number of tokens to return.
    /// - Returns: An array of tokens and an array of token identifiers.
    public func tokenize(input: String, minCount: Int? = nil) -> (tokens: [String], tokenIDs: [Int]) {
        var tokens: [String] = []

        tokens.append(startToken)
        tokens.append(contentsOf: encode(input: input))
        tokens.append(endToken)

        // Pad if there was a min length specified
        if let minLen = minCount, minLen > tokens.count {
            tokens.append(contentsOf: repeatElement(padToken, count: minLen - tokens.count))
        }

        let ids = tokens.map({ vocabulary[$0, default: unknownTokenID] })
        return (tokens: tokens, tokenIDs: ids)
    }
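
    // A usage sketch (the token strings are illustrative; identifiers depend on
    // the loaded vocabulary, and a word only collapses to a single token when
    // the merge table allows it):
    //
    //     let (tokens, ids) = tokenizer.tokenize(input: "a photo", minCount: 6)
    //     // tokens -> ["<|startoftext|>", "a</w>", "photo</w>",
    //     //            "<|endoftext|>", "<|endoftext|>", "<|endoftext|>"]
    //     // ids -> one identifier per token, with unknownTokenID for any
    //     //        token missing from the vocabulary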

    /// Returns the token identifier for a token.
    public func tokenID(for token: String) -> Int? {
        vocabulary[token]
    }

    /// Returns the token for a token identifier.
    public func token(id: Int) -> String? {
        vocabulary.first(where: { $0.value == id })?.key
    }

    /// Decodes a sequence of tokens into a fully formed string.
    public func decode(tokens: [String]) -> String {
        String(tokens.joined())
            .replacingOccurrences(of: "</w>", with: " ")
            .replacingOccurrences(of: startToken, with: "")
            .replacingOccurrences(of: endToken, with: "")
    }

    /// Encode an input string to a sequence of tokens
    func encode(input: String) -> [String] {
        let normalized = input.trimmingCharacters(in: .whitespacesAndNewlines).lowercased()
        let words = normalized.split(separator: " ")
        return words.flatMap({ encode(word: $0) })
    }

    /// Encode a single word into a sequence of tokens
    func encode(word: Substring) -> [String] {
        var tokens = word.map { String($0) }
        if let last = tokens.indices.last {
            tokens[last] = tokens[last] + "</w>"
        }

        while true {
            let pairs = pairs(for: tokens)
            let canMerge = pairs.filter { merges[$0] != nil }

            if canMerge.isEmpty {
                break
            }

            // If multiple merges are found, use the one with the lowest rank
            let shouldMerge = canMerge.min { merges[$0]! < merges[$1]! }!
            tokens = update(tokens, merging: shouldMerge)
        }

        return tokens
    }
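
    // For example, encode(word: "low") proceeds roughly as follows, assuming a
    // (hypothetical) merge table that contains ("l", "o") and nothing else that
    // applies here:
    //
    //     ["l", "o", "w</w>"]    characters, with "</w>" marking the word end
    //     ["lo", "w</w>"]        ("l", "o") is the lowest-ranked mergeable pair
    //     ["lo", "w</w>"]        ("lo", "w</w>") is not in `merges`, loop ends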

    /// Get the set of adjacent pairs / bigrams from a sequence of tokens
    func pairs(for tokens: [String]) -> Set<TokenPair> {
        guard tokens.count > 1 else {
            return Set()
        }

        var pairs = Set<TokenPair>(minimumCapacity: tokens.count - 1)
        var prev = tokens.first!
        for current in tokens.dropFirst() {
            pairs.insert(TokenPair(prev, current))
            prev = current
        }
        return pairs
    }
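
    // For example, pairs(for: ["lo", "w", "er</w>"]) returns the set
    // {("lo", "w"), ("w", "er</w>")}. Repeated bigrams collapse into a single
    // entry, which suffices because `update` merges every occurrence in one pass.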

    /// Update the sequence of tokens by greedily merging instances of a specific bigram
    func update(_ tokens: [String], merging bigram: TokenPair) -> [String] {
        guard tokens.count > 1 else {
            return []
        }

        var newTokens = [String]()
        newTokens.reserveCapacity(tokens.count - 1)

        var index = 0
        while index < tokens.count {
            let remainingTokens = tokens[index...]
            if let startMatchIndex = remainingTokens.firstIndex(of: bigram.first) {
                // Found a possible match, append everything before it
                newTokens.append(contentsOf: tokens[index..<startMatchIndex])

                // Bounds-check against the match position, not `index`, so the
                // lookahead never reads past the end of the array
                if startMatchIndex < tokens.count - 1 && tokens[startMatchIndex + 1] == bigram.second {
                    // Full match, merge
                    newTokens.append(bigram.first + bigram.second)
                    index = startMatchIndex + 2
                } else {
                    // Only matched the first, no merge
                    newTokens.append(bigram.first)
                    index = startMatchIndex + 1
                }
            } else {
                // Didn't find any more matches, append the rest unmerged
                newTokens.append(contentsOf: remainingTokens)
                break
            }
        }
        return newTokens
    }
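
    // For example, update(["l", "o", "w", "l", "o"], merging: TokenPair("l", "o"))
    // returns ["lo", "w", "lo"]: every occurrence of the bigram is merged in a
    // single left-to-right pass.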
}

@available(iOS 16.2, macOS 13.1, *)
extension BPETokenizer {
    /// A hashable tuple of strings
    public struct TokenPair: Hashable {
        let first: String
        let second: String

        init(_ first: String, _ second: String) {
            self.first = first
            self.second = second
        }
    }
}
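
// A usage sketch (the file names below are hypothetical; the real merges and
// vocabulary files ship with the model resources):
//
//     let tokenizer = try BPETokenizer(
//         mergesAt: URL(fileURLWithPath: "merges.txt"),
//         vocabularyAt: URL(fileURLWithPath: "vocab.json"))
//     let (tokens, ids) = tokenizer.tokenize(input: "a photo of an astronaut", minCount: 77)
//     let text = tokenizer.decode(tokens: tokens)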