|
19 | 19 | from contextlib import asynccontextmanager
|
20 | 20 |
|
21 | 21 | # Imports required by the service's model
|
22 |
| -# TODO: 1. ADD REQUIRED IMPORTS (ALSO IN THE REQUIREMENTS.TXT) |
23 | 22 | from outlines import generate, models
|
24 | 23 | import torch
|
25 | 24 | import json
|
|
32 | 31 | device = "mps"
|
33 | 32 | else:
|
34 | 33 | device = "cpu"
|
| 34 | +print(f"Using device: {device}") |
| 35 | +model = models.transformers("Qwen/Qwen2.5-3B-Instruct", device=device) |
| 36 | + |
35 | 37 |
|
36 | 38 | class MyService(Service):
|
37 | 39 | """
|
@@ -79,19 +81,19 @@ def __init__(self):
|
79 | 81 | docs_url="https://docs.swiss-ai-center.ch/reference/core-concepts/service/",
|
80 | 82 | )
|
81 | 83 | self._logger = get_logger(settings)
|
82 |
| - self._model = models.transformers("Qwen/Qwen2.5-3B-Instruct", device=device) |
| 84 | + self._model = model |
83 | 85 |
|
84 |
| - # TODO: 5. CHANGE THE PROCESS METHOD (CORE OF THE SERVICE) |
85 | 86 | def process(self, data):
|
86 | 87 | json_schema = data["format"].data.decode("utf-8")
|
87 | 88 | prompt = data["prompt"].data.decode("utf-8")
|
88 |
| - |
| 89 | + |
89 | 90 | # Use Outlines library to format LLM outputs
|
90 |
| - |
| 91 | + |
91 | 92 | generator = generate.json(self._model, json_schema)
|
92 | 93 | result = generator(prompt)
|
93 |
| - |
94 |
| - result = json.dumps(result) |
| 94 | + |
| 95 | + # json to bytes |
| 96 | + result = json.dumps(result).encode("utf-8") |
95 | 97 |
|
96 | 98 | # NOTE that the result must be a dictionary with the keys being the field names set in the data_out_fields
|
97 | 99 | return {
|
@@ -149,7 +151,6 @@ async def announce():
|
149 | 151 | await service_service.graceful_shutdown(my_service, engine_url)
|
150 | 152 |
|
151 | 153 |
|
152 |
| -# TODO: 6. CHANGE THE API DESCRIPTION AND SUMMARY |
153 | 154 | api_description = """
|
154 | 155 | Uses Outlines library to format LLM outputs.
|
155 | 156 | """
|
|
0 commit comments