---
output: github_document
---
<!-- README.md is generated from README.Rmd. Please edit that file -->
```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  eval = FALSE,
  comment = "#>",
  fig.path = "man/figures/README-",
  out.width = "100%"
)
```
# huggr
<!-- badges: start -->
<!-- badges: end -->
huggr provides tools to use [Hugging Face](https://huggingface.co/) models from R through the Python `transformers` library.
For now, it mainly extracts embeddings (BERT, RoBERTa) and generates text (T5); further capabilities such as mask filling, fine-tuning and classification will eventually be added.
## Installation
``` r
# install.packages("devtools")
devtools::install_github("benjaminguinaudeau/huggr")
```
## Example
### Set up the Python environment
```{bash, eval = FALSE}
conda create -n mbert python=3.7
conda activate mbert
pip install torch
pip install transformers
pip install sentencepiece
```
### Set up reticulate with the right Python environment
```{r}
reticulate::use_condaenv("mbert", required = TRUE)
options(python_init = TRUE)
library(reticulate)
library(magrittr)  # provides the %>% pipe used in the examples below
library(huggr)
load_huggr_dep()
```
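If `load_huggr_dep()` fails, it can help to confirm that reticulate actually picked up the environment created above. The sketch below is a hypothetical helper, not part of huggr: it asks `reticulate::py_module_available()` whether each package installed earlier can be imported and warns about any that cannot. The module names come from the `pip install` commands; the `available` argument exists only so the check can be swapped for a stand-in, for example when testing without Python.

```r
# Hypothetical helper (not part of huggr): report which Python modules the
# active reticulate environment can import.
check_python_deps <- function(modules = c("torch", "transformers", "sentencepiece"),
                              available = reticulate::py_module_available) {
  # Named logical vector: one TRUE/FALSE per module
  ok <- vapply(modules, function(m) isTRUE(available(m)), logical(1))
  if (!all(ok)) {
    warning("Missing Python modules: ", paste(modules[!ok], collapse = ", "),
            call. = FALSE)
  }
  ok
}
```

Calling `check_python_deps()` right after `reticulate::use_condaenv("mbert", required = TRUE)` returns one `TRUE`/`FALSE` per module, named after the module.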
### BERT
```{r, warning = FALSE}
py$bert_download(model_name = "bert-base-multilingual-cased",
                 path = "/data/hugg_dep/models/bert-base-multilingual-cased")
huggr_bert <- py$huggr_bert(path = "/data/hugg_dep/models/bert-base-multilingual-cased", gpu = TRUE)
# German: "How are you?", "Practice makes perfect", "I'm doing well"
text <- c("Wie geht's dir?", "Die Übung macht den Meister", "Mir geht's gut")
huggr_bert$get_embedding(text) %>%
  purrr::map_dfr(~tibble::as_tibble(t(.x))) %>%
  .[, 1:10] %>%
  dplyr::glimpse()
```
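The `purrr::map_dfr()` step above turns the embeddings into a table with one row per sentence. The base-R sketch below performs the same reshaping on made-up data. It assumes, as that pipeline suggests, that `get_embedding()` returns a list holding one numeric vector per input text; the random 768-dimensional vectors (the hidden size of BERT-base) are stand-ins, and the real length depends on what `get_embedding()` returns.

```r
# Sentences as in the chunk above
text <- c("Wie geht's dir?", "Die Übung macht den Meister", "Mir geht's gut")

# Stand-in for huggr_bert$get_embedding(text): assumed to be a list with one
# numeric vector per sentence. Random values, 768 dimensions (BERT-base size).
set.seed(42)
fake_embeddings <- lapply(seq_along(text), function(i) rnorm(768))

# Stack the vectors into a sentences x dimensions matrix, labelled by sentence
embedding_matrix <- do.call(rbind, fake_embeddings)
rownames(embedding_matrix) <- text
dim(embedding_matrix)
#> [1]   3 768

# One row per sentence, unnamed columns become V1, V2, ...
embedding_df <- as.data.frame(embedding_matrix)
names(embedding_df)[1:3]
#> [1] "V1" "V2" "V3"
```

Keeping the sentences as row names ties each embedding back to its input, which the tibble built by `map_dfr()` does not do on its own.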
### RoBERTa
```{r}
# Uncomment on first run to download the model:
# py$roberta_download(model_name = "cardiffnlp/twitter-xlm-roberta-base",
#                     path = "/data/res/hugg_dep/models/twitter-xlm-roberta-base")
rob <- py$huggr_roberta("/data/res/hugg_dep/models/twitter-xlm-roberta-base")
text <- c("Wie geht's dir?", "Die Übung macht den Meister", "Mir geht's gut")
text %>%
  roberta_clean() %>%
  rob$get_embedding() %>%
  purrr::map_dfr(~tibble::as_tibble(t(.x))) %>%
  .[, 1:10] %>%
  dplyr::glimpse()
```
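The cardiffnlp Twitter models were trained on tweets in which user mentions and links were normalized, and the model cards of cardiffnlp's Twitter models recommend replacing every `@mention` with `@user` and every link with `http` before inference. The sketch below ports that preprocessing to base R as a hypothetical `preprocess_tweet()`; whether `roberta_clean()` does the same is an assumption, so check its source before relying on the equivalence.

```r
# Hypothetical port of the preprocessing shown on cardiffnlp model cards;
# not taken from huggr::roberta_clean().
preprocess_tweet <- function(text) {
  tokens_per_text <- strsplit(text, " ", fixed = TRUE)
  vapply(tokens_per_text, function(tokens) {
    # Mentions ("@" followed by at least one character) become "@user"
    tokens[startsWith(tokens, "@") & nchar(tokens) > 1] <- "@user"
    # Tokens starting with "http" (links) become "http"
    tokens[startsWith(tokens, "http")] <- "http"
    paste(tokens, collapse = " ")
  }, character(1))
}

preprocess_tweet(c("@anna look at https://huggingface.co", "Mir geht's gut"))
#> [1] "@user look at http" "Mir geht's gut"
```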
### T5
```{r, warning = FALSE}
py$t5_download(model_name = "t5-small",
               path = "/data/hugg_dep/models/t5-small")
huggr_t5 <- py$huggr_t5(path = "/data/hugg_dep/models/t5-small/", gpu = TRUE)
text <- c("How are you?", "Canada is the best country for wood chopping", "Penguins are the cutest living animals.")
huggr_t5$generate_text(task = "translate English to German:", text = text)
```
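T5 casts every task as text-to-text and selects the task from a plain-text prefix on its input, which is what the `task` argument above supplies. Besides `translate English to German:`, `t5-small` was trained with prefixes such as `translate English to French:`, `translate English to Romanian:` and `summarize:`. The sketch below builds such inputs by hand with a hypothetical `t5_prompt()` helper; it assumes that `generate_text()` combines `task` and `text` this way, which the README does not confirm.

```r
# Task prefixes from T5's multi-task training mixture
t5_tasks <- c(
  de = "translate English to German:",
  fr = "translate English to French:",
  ro = "translate English to Romanian:",
  summary = "summarize:"
)

# Hypothetical helper (not part of huggr): prefix each text with the task,
# separated by a single space, giving the string T5 actually reads.
t5_prompt <- function(task, text) {
  paste(trimws(task), text)
}

t5_prompt(t5_tasks[["fr"]], "How are you?")
#> [1] "translate English to French: How are you?"
```

If that assumption holds, passing `task = t5_tasks[["fr"]]` to `huggr_t5$generate_text()` should return French rather than German translations.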