This repository was archived by the owner on Jun 26, 2024. It is now read-only.
File tree 5 files changed +41
-3
lines changed
5 files changed +41
-3
lines changed Original file line number Diff line number Diff line change 47
47
python --version
48
48
pip --version
49
49
python -m pip install --upgrade pip
50
- pip install flit
51
- flit install --deps all
52
50
pip list
53
51
shell : bash
54
52
Original file line number Diff line number Diff line change 1
1
"""Inference API for LLaMA"""
2
2
3
- from .api import LLaMAInference
3
+ from .model import LLaMAInference
4
4
5
5
# Package version; placeholder until the first tagged release.
__version__ = "0.0.0"
Original file line number Diff line number Diff line change
1
+ from typing import Any
2
+
3
+ import lightning as L
4
+ from lightning .app .components import PythonServer
5
+ from pydantic import BaseModel
6
+
7
+ from llama_inference .model import LLaMAInference
8
+
9
+
10
class PromptRequest(BaseModel):
    """Request payload carrying the text prompt to run inference on."""

    prompt: str
12
+
13
+
14
class Response(BaseModel):
    """Response payload wrapping the generated text."""

    result: str
16
+
17
+
18
class ServeLLaMA(PythonServer):
    """Lightning ``PythonServer`` that exposes ``LLaMAInference`` over HTTP.

    ``setup`` builds the model once at server start; ``predict`` runs one
    generation per request.
    """

    def setup(self, *args: Any, **kwargs: Any) -> None:
        # Construct the model a single time when the server boots;
        # all arguments are forwarded verbatim to LLaMAInference.
        self._model = LLaMAInference(*args, **kwargs)

    def predict(self, request: PromptRequest) -> Any:
        # Generate text for the incoming prompt and wrap it in the
        # response schema.
        generated = self._model(request.prompt)
        return Response(result=generated)
25
+
26
+
27
if __name__ == "__main__":
    # Module-level `app` is the conventional entry point the Lightning CLI
    # looks up when launching the application.
    app = L.LightningApp(ServeLLaMA(input_type=PromptRequest, output_type=Response))
Original file line number Diff line number Diff line change
1
+ import os
1
2
import sys
2
3
import time
3
4
from pathlib import Path
4
5
from typing import Optional
5
6
6
7
import lightning as L
7
8
import torch
9
+ from dotenv import load_dotenv
8
10
from lit_llama import LLaMA , Tokenizer
9
11
from lit_llama .utils import EmptyInitOnDevice
10
12
13
# Pull variables from a local .env file (if one exists) so the weights
# location can be configured without exporting environment variables.
load_dotenv()

# Optional root directory containing LLaMA checkpoints; None when unset.
WEIGHTS_PATH = os.getenv("WEIGHTS")
11
17
12
18
@torch .no_grad ()
13
19
def _generate (
@@ -74,6 +80,10 @@ def __init__(
74
80
) -> None :
75
81
self .fabric = fabric = L .Fabric (accelerator = accelerator , devices = 1 )
76
82
83
+ if not checkpoint_path and WEIGHTS_PATH :
84
+ checkpoint_path = f"{ WEIGHTS_PATH } /{ model_size } /state_dict.pth"
85
+ tokenizer_path = f"{ WEIGHTS_PATH } /tokenizer.model"
86
+
77
87
if dtype is not None :
78
88
dt = getattr (torch , dtype , None )
79
89
if not isinstance (dt , torch .dtype ):
Original file line number Diff line number Diff line change
1
+ python-dotenv >= 1.0.0
You can’t perform that action at this time.
0 commit comments