This repository was archived by the owner on Jun 26, 2024. It is now read-only.

Commit 33908f2

Deploy LLaMA with Lightning App (#1)
* add app
* update
* update
* refactor
1 parent 8d4e95f commit 33908f2

File tree: 5 files changed, +41 −3 lines

.github/workflows/main.yml (−2 lines)

@@ -47,8 +47,6 @@ jobs:
         python --version
         pip --version
         python -m pip install --upgrade pip
-        pip install flit
-        flit install --deps all
         pip list
       shell: bash

src/llama_inference/__init__.py (+1 −1)

@@ -1,5 +1,5 @@
 """Inference API for LLaMA"""

-from .api import LLaMAInference
+from .model import LLaMAInference

 __version__ = "0.0.0"

src/llama_inference/app.py (new file, +29 lines)

@@ -0,0 +1,29 @@
+from typing import Any
+
+import lightning as L
+from lightning.app.components import PythonServer
+from pydantic import BaseModel
+
+from llama_inference.model import LLaMAInference
+
+
+class PromptRequest(BaseModel):
+    prompt: str
+
+
+class Response(BaseModel):
+    result: str
+
+
+class ServeLLaMA(PythonServer):
+    def setup(self, *args: Any, **kwargs: Any) -> None:
+        self._model = LLaMAInference(*args, **kwargs)
+
+    def predict(self, request: PromptRequest) -> Any:
+        result = self._model(request.prompt)
+        return Response(result=result)
+
+
+if __name__ == "__main__":
+    component = ServeLLaMA(input_type=PromptRequest, output_type=Response)
+    app = L.LightningApp(component)
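Once the app is launched (for example with lightning run app src/llama_inference/app.py), the ServeLLaMA component should accept prompts as JSON over HTTP. A minimal client sketch, assuming PythonServer's default /predict route and a locally served URL on port 7777 (both the route and the port are assumptions; check the URL the app prints on startup):

# client.py -- hypothetical client for the ServeLLaMA endpoint
import requests

SERVER_URL = "http://127.0.0.1:7777/predict"  # assumed local URL and route

# The request body mirrors the PromptRequest schema defined in app.py.
response = requests.post(SERVER_URL, json={"prompt": "Hello, my name is"})
response.raise_for_status()

# The response body mirrors the Response schema: {"result": "..."}
print(response.json()["result"])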

src/llama_inference/api.py → src/llama_inference/model.py (renamed, +10 lines)

@@ -1,13 +1,19 @@
+import os
 import sys
 import time
 from pathlib import Path
 from typing import Optional

 import lightning as L
 import torch
+from dotenv import load_dotenv
 from lit_llama import LLaMA, Tokenizer
 from lit_llama.utils import EmptyInitOnDevice

+load_dotenv()
+
+WEIGHTS_PATH = os.environ.get("WEIGHTS")
+

 @torch.no_grad()
 def _generate(

@@ -74,6 +80,10 @@ def __init__(
     ) -> None:
         self.fabric = fabric = L.Fabric(accelerator=accelerator, devices=1)

+        if not checkpoint_path and WEIGHTS_PATH:
+            checkpoint_path = f"{WEIGHTS_PATH}/{model_size}/state_dict.pth"
+            tokenizer_path = f"{WEIGHTS_PATH}/tokenizer.model"
+
         if dtype is not None:
             dt = getattr(torch, dtype, None)
             if not isinstance(dt, torch.dtype):
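With this change, LLaMAInference can fall back to a WEIGHTS environment variable (loaded from a .env file via python-dotenv) to locate the checkpoint and tokenizer when no checkpoint_path is passed. A minimal usage sketch, assuming the weights directory layout shown in the diff and a "7B" model_size value (the constructor's other defaults are not shown in this commit and are assumptions):

# Sketch: point WEIGHTS at a directory containing <size>/state_dict.pth
# and tokenizer.model, then let LLaMAInference resolve the paths itself.
import os

os.environ["WEIGHTS"] = "/data/llama"  # hypothetical weights directory

from llama_inference.model import LLaMAInference  # import after WEIGHTS is set

# Resolves /data/llama/7B/state_dict.pth and /data/llama/tokenizer.model
model = LLaMAInference(model_size="7B")

# __call__ takes a prompt string, mirroring how app.py's predict() uses it
print(model("Hello, my name is"))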

src/requirements.txt (+1 line)

@@ -0,0 +1 @@
+python-dotenv>=1.0.0
