Skip to content

Commit

Permalink
Improve BPETokenizer.readMerges performance (#169)
Browse files Browse the repository at this point in the history
From 240ms to 19ms on M1 MacBook Pro

Co-authored-by: Alejandro Isaza <[email protected]>
  • Loading branch information
alejandro-isaza and alejandro-isaza authored Apr 27, 2023
1 parent fbef6c3 commit a7bf219
Showing 1 changed file with 35 additions and 12 deletions.
47 changes: 35 additions & 12 deletions swift/StableDiffusion/tokenizer/BPETokenizer+Reading.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,42 @@ extension BPETokenizer {

/// Read merges.txt file at URL into a dictionary mapping bigrams to the line number/rank/priority
static func readMerges(url: URL) throws -> [TokenPair: Int] {
let content = try String(contentsOf: url)
let lines = content.split(separator: "\n")

let merges: [(TokenPair, Int)] = try lines.enumerated().compactMap { (index, line) in
if line.hasPrefix("#") {
return nil
}
let pair = line.split(separator: " ")
if pair.count != 2 {
throw FileReadError.invalidMergeFileLine(index+1)
let data = try Data(contentsOf: url)
var merges = [TokenPair: Int]()
var index = 0
var line = [UInt8]()
for byte in data {
if byte == UInt8(ascii: "\n") {
if let pair = try parseMergesLine(line, index: index) {
merges[pair] = index
}
line.removeAll(keepingCapacity: true)
index += 1
} else {
line.append(byte)
}
return (TokenPair(String(pair[0]), String(pair[1])),index)
}
return [TokenPair : Int](uniqueKeysWithValues: merges)

return merges
}

static func parseMergesLine(_ line: [UInt8], index: Int) throws -> TokenPair? {
if line.isEmpty || line.first == UInt8(ascii: "#") {
return nil
}
let pair = line.split(separator: UInt8(ascii: " "))
if pair.count != 2 {
throw FileReadError.invalidMergeFileLine(index + 1)
}
return TokenPair( String(bytes: pair[0]), String(bytes: pair[1]))
}
}

extension String {
init(bytes: some Collection<UInt8>) {
self.init(unsafeUninitializedCapacity: bytes.count) { pointer in
_ = pointer.initialize(fromContentsOf: bytes)
return bytes.count
}
}
}

0 comments on commit a7bf219

Please sign in to comment.