2022-12-28
Torch on M1 GPU
It was a bit tricky getting rust-bert to work on the M1 GPU. The issue, apparently, is that PyTorch JIT models not trained for MPS (the macOS GPU framework) can't be loaded directly onto an MPS device.
But it turns out there's an easy solution: load the model on the CPU first and then convert it. After loading the VarStore on the CPU device, it's just a matter of calling var_store.set_device(tch::Device::Mps), and then you're running on the GPU!
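Here's a minimal sketch of that flow using tch directly; the model construction step is elided and the weights path is just a placeholder, not the one the project actually uses:

```rust
use tch::{nn, Device};

fn main() -> Result<(), tch::TchError> {
    // Create the variable store on the CPU; loading JIT-exported weights
    // straight onto MPS is what fails.
    let mut var_store = nn::VarStore::new(Device::Cpu);

    // ... build the model against var_store.root() here ...

    // Placeholder path for the converted weights file.
    var_store.load("model/rust_model.ot")?;

    // Move all of the variables over to the Apple GPU.
    var_store.set_device(Device::Mps);

    Ok(())
}
```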
In my initial tests with an M1 Pro, this is about 2-3x as fast as running on CPU/AMX. This took the time to scan and index my Logseq database (~1000 documents) down to 6 seconds. Curious if this would have been 3 seconds on an M1 Max, but I didn't spend the extra $400 a couple years ago to find out now. :)
Switching Models
The MiniLmL12V2 model that I started out with is trained more for "sentence similarity" than for searching longer documents, and it shows. I switched the model to msmarco-bert-base-dot-v5, which is supposed to work a lot better for semantic search, and indeed the search results improved immensely. The import process takes a lot longer (40 seconds for ~1000 documents), but that's still not bad. That GPU inference is pulling its weight.
These models aren't automatically supported by rust-bert, but the instructions on how to download and use other models worked great, and this one is similar enough to the existing sentence embedding pipeline that I didn't have to change much.
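For reference, pointing rust-bert's sentence embeddings pipeline at a locally downloaded model looks roughly like this; the directory path is a placeholder for wherever the converted msmarco-bert-base-dot-v5 weights end up:

```rust
use rust_bert::pipelines::sentence_embeddings::SentenceEmbeddingsBuilder;
use tch::Device;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder path: a local directory containing the converted
    // msmarco-bert-base-dot-v5 weights, tokenizer, and config.
    let model = SentenceEmbeddingsBuilder::local("models/msmarco-bert-base-dot-v5")
        .with_device(Device::Mps)
        .create_model()?;

    // Encoding works the same way as with the bundled remote models.
    let embeddings = model.encode(&["semantic search over my notes"])?;
    println!("embedding length: {}", embeddings[0].len());
    Ok(())
}
```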
Search Highlighting
Finally, I implemented search result highlighting, so that you get not only the title of the found document but also a snippet of text from the document relevant to the query. I'm now using two models in the program at once. The primary model used for the search is still a BERT-based model, and it handles the full document encoding.
The BERT model is powerful but relatively slow, so for highlighting I used the MiniLmL6V2 model, which is both much faster and focused on small strings of text.
Then for each matching document, I tokenize it, break the list of tokens into overlapping chunks, and encode each chunk with the model. Finally, I encode the original query as well and take the dot product between the query embedding and each chunk embedding; the chunk with the highest dot product in each document is the best-matching snippet to show.
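A rough sketch of that highlighting pass is below. Whitespace tokens stand in for the real tokenizer, and the window and stride sizes are made up for illustration:

```rust
use rust_bert::pipelines::sentence_embeddings::{
    SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType,
};

/// Dot product between a query embedding and a chunk embedding.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Split a token list into overlapping windows and join each window back
/// into a string the embeddings model can encode.
fn overlapping_chunks(tokens: &[&str], window: usize, stride: usize) -> Vec<String> {
    let mut chunks = Vec::new();
    let mut start = 0;
    while start < tokens.len() {
        let end = (start + window).min(tokens.len());
        chunks.push(tokens[start..end].join(" "));
        if end == tokens.len() {
            break;
        }
        start += stride;
    }
    chunks
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // The small, fast model used only for highlighting.
    let highlighter =
        SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::AllMiniLmL6V2)
            .create_model()?;

    // Placeholder inputs; the real code passes the full document body here.
    let document = "placeholder text standing in for the body of a matching document";
    let query = "torch on the M1 GPU";

    // Whitespace tokens stand in for the model tokenizer in this sketch.
    let tokens: Vec<&str> = document.split_whitespace().collect();
    let chunks = overlapping_chunks(&tokens, 64, 32);

    let chunk_embeddings = highlighter.encode(&chunks)?;
    let query_embeddings = highlighter.encode(&[query])?;
    let query_embedding = &query_embeddings[0];

    // The chunk whose embedding has the highest dot product with the query
    // embedding is the snippet shown in the search results.
    let best = chunk_embeddings
        .iter()
        .enumerate()
        .max_by(|a, b| dot(query_embedding, a.1).total_cmp(&dot(query_embedding, b.1)))
        .map(|(i, _)| i);

    if let Some(i) = best {
        println!("best chunk: {}", chunks[i]);
    }
    Ok(())
}
```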
I think it could use some tweaking to pay more attention to actual word boundaries in the tokens. But overall I'm quite happy with this as a first effort in a few hours.