Skip to content

Commit

Permalink
Merge pull request #860 from dimagi/sk/llm-provider-updates
Browse files Browse the repository at this point in the history
LLM Provider updates
  • Loading branch information
snopoke authored Nov 14, 2024
2 parents 36f4a10 + e6560ff commit a0165ca
Show file tree
Hide file tree
Showing 13 changed files with 207 additions and 102 deletions.
3 changes: 2 additions & 1 deletion apps/generics/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@ class BaseTypeSelectFormView(views.View):
title = None
extra_context = None
model = None
template = "generic/type_select_form.html"

_object = None

def get(self, request, *args, **kwargs):
form = self.get_form()
return render(request, "generic/type_select_form.html", self.get_context_data(form))
return render(request, self.template, self.get_context_data(form))

def post(self, request, *args, **kwargs):
form = self.get_form(request.POST)
Expand Down
29 changes: 10 additions & 19 deletions apps/service_providers/forms.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from django import forms
from django.core.validators import URLValidator
from django.urls import reverse
from django.utils.html import format_html
from django.utils.translation import gettext_lazy as _

from apps.files.forms import BaseFileFormSet
from apps.service_providers.models import LlmProvider, LlmProviderModel, LlmProviderTypes
from apps.service_providers.models import LlmProviderModel


class ProviderTypeConfigForm(forms.Form):
Expand Down Expand Up @@ -245,31 +243,24 @@ class LlmProviderModelForm(forms.ModelForm):
class Meta:
model = LlmProviderModel
fields = ("type", "name", "max_token_limit")
widgets = {
"type": forms.HiddenInput(),
}

def __init__(self, team, *args, **kwargs):
self.team = team
super().__init__(*args, **kwargs)
types = LlmProvider.objects.filter(team=team).values_list("type", flat=True).all()
self.fields["type"].choices = [choice for choice in LlmProviderTypes.choices if choice[0] in types]
if len(types) == 0:
url = reverse("service_providers:new", kwargs={"team_slug": team.slug, "provider_type": "llm"})
self.fields["type"].help_text = format_html(
_('You must create an <a class="link" href="{}">LLM provider</a> first'), url
)

def clean(self):
cleaned_data = super().clean()
name = cleaned_data.get("name")
max_token_limit = cleaned_data.get("max_token_limit")

if name and max_token_limit:
if (
LlmProviderModel.objects.filter(team=self.team, name=name, max_token_limit=max_token_limit)
.exclude(pk=self.instance.pk if self.instance else None)
.exists()
):
raise forms.ValidationError(
{"__all__": _("A model with this name and max token limit already exists for your team")}
)
if (
LlmProviderModel.objects.filter(team=self.team, name=name, max_token_limit=max_token_limit)
.exclude(pk=self.instance.pk if self.instance else None)
.exists()
):
raise forms.ValidationError(_("A model with this name and max token limit already exists for your team"))

return cleaned_data
6 changes: 3 additions & 3 deletions apps/service_providers/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
app_name = "service_providers"

urlpatterns = [
path("llm_provider_model/table/", views.LlmProviderModelTableView.as_view(), name="llm_provider_model_table"),
path("llm_provider_model/create/", views.LlmProviderModelView.as_view(), name="llm_provider_model_new"),
path("llm_provider_model/<int:pk>/", views.LlmProviderModelView.as_view(), name="llm_provider_model_edit"),
path("llm_provider_model/create/", views.create_llm_provider_model, name="llm_provider_model_new"),
path(
"llm_provider_model/<int:pk>/delete/",
views.delete_llm_provider_model,
name="llm_provider_model_delete",
),
path("llm/create/", views.LlmProviderView.as_view(), name="llm_new"),
path("llm/<int:pk>/", views.LlmProviderView.as_view(), name="llm_edit"),
path("<slug:provider_type>/table/", views.ServiceProviderTableView.as_view(), name="table"),
path("<slug:provider_type>/create/", views.CreateServiceProvider.as_view(), name="new"),
path("<slug:provider_type>/<int:pk>/", views.CreateServiceProvider.as_view(), name="edit"),
Expand Down
108 changes: 45 additions & 63 deletions apps/service_providers/views.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
from collections import defaultdict

from django.conf import settings
from django.contrib.auth.decorators import login_required, permission_required
from django.contrib.auth.mixins import PermissionRequiredMixin
from django.core.exceptions import ValidationError
from django.db import transaction
from django.http import HttpResponse, HttpResponseBadRequest
from django.shortcuts import get_object_or_404, resolve_url
from django.shortcuts import get_object_or_404, render, resolve_url
from django.urls import reverse
from django.views.decorators.http import require_http_methods
from django.views.generic import TemplateView
from django.views.generic.edit import ModelFormMixin, SingleObjectMixin
from django.views.decorators.http import require_http_methods, require_POST
from django_tables2 import SingleTableView
from waffle import flag_is_active

from apps.files.views import BaseAddFileHtmxView
from apps.service_providers.forms import LlmProviderModelForm
from apps.service_providers.models import LlmProviderModel, MessagingProviderType, VoiceProviderType
from apps.service_providers.tables import LlmProviderModelTable

from ..generics.views import BaseTypeSelectFormView
from ..teams.decorators import login_and_team_required
from .utils import ServiceProvider, get_service_provider_config_form


Expand Down Expand Up @@ -117,71 +116,54 @@ def get_success_url(self):
return resolve_url("single_team:manage_team", team_slug=self.request.team.slug)


class LlmProviderModelTableView(PermissionRequiredMixin, SingleTableView):
permission_required = "service_providers.view_llmprovidermodel"
paginate_by = 25
template_name = "table/single_table.html"
model = LlmProviderModel
table_class = LlmProviderModelTable

def get_queryset(self):
return LlmProviderModel.objects.filter(team=self.request.team)


class LlmProviderModelView(PermissionRequiredMixin, ModelFormMixin, SingleObjectMixin, TemplateView):
permission_required = ("service_providers.add_llmprovidermodel", "service_providers.change_llmprovidermodel")
model = LlmProviderModel
form_class = LlmProviderModelForm
template_name = "generic/object_form.html"
class LlmProviderView(CreateServiceProvider):
template = "service_providers/llm_provider_form.html"

def get_form_kwargs(self):
return {"team": self.request.team, **super().get_form_kwargs()}
@property
def provider_type(self) -> ServiceProvider:
return ServiceProvider.llm

@property
def extra_context(self):
default_models_by_type = _get_models_by_type(LlmProviderModel.objects.filter(team=None))
custom_models_type_type = _get_models_by_type(LlmProviderModel.objects.filter(team=self.request.team))
return {
"title": self._get_title(),
"button_text": "Save",
"active_tab": "manage-team",
"title": self.provider_type.label,
"default_models_by_type": default_models_by_type,
"custom_models_by_type": custom_models_type_type,
"new_model_form": LlmProviderModelForm(self.request.team),
}

def _get_title(self):
if self.object:
return "Edit Custom LLM Model"
return "Create Custom LLM Model"

def get_context_data(self, **kwargs):
context = super().get_context_data(**kwargs)
context.update(self.extra_context)
return context

def get(self, request, *args, **kwargs):
if "pk" in self.kwargs:
self.object = self.get_object()
else:
self.object = None
return self.render_to_response(self.get_context_data(form=self.get_form()))

def post(self, request, *args, **kwargs):
if "pk" in self.kwargs:
self.object = self.get_object()
else:
self.object = None
form = self.get_form()
if form.is_valid():
return self.form_valid(form)
else:
return self.form_invalid(form)

def get_success_url(self):
return resolve_url("single_team:manage_team", team_slug=self.request.team.slug)

def form_valid(self, form):
if not self.object:
form.instance.team = self.request.team
return super().form_valid(form)

def get_queryset(self):
return LlmProviderModel.objects.filter(team=self.request.team)
def _get_models_by_type(queryset):
models_by_type = defaultdict(list)
for model in queryset:
models_by_type[model.type].append(model)
return {key: sorted(value, key=lambda x: x.name) for key, value in models_by_type.items()}


@require_POST
@login_and_team_required
@permission_required("service_providers.add_llmprovidermodel")
def create_llm_provider_model(request, team_slug: str):
form = LlmProviderModelForm(request.team, request.POST)
if form.is_valid():
model = form.save(commit=False)
model.team = request.team
model.save()
else:
if len(form.errors) == 1 and "__all__" in form.errors:
return HttpResponseBadRequest(", ".join([str(v) for v in form.errors.values()]))
return HttpResponseBadRequest(str(form.errors))
return render(
request,
"service_providers/components/custom_llm_models.html",
{
"models_by_type": _get_models_by_type(LlmProviderModel.objects.filter(team=request.team)),
"for_type": form.cleaned_data["type"],
},
)


@require_http_methods(["DELETE"])
Expand Down
13 changes: 13 additions & 0 deletions apps/web/templatetags/default_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from django import template

register = template.Library()


@register.simple_tag
def define(val=None):
return val


@register.filter(name="times")
def times(number):
return range(number)
3 changes: 3 additions & 0 deletions gpt_playground/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@
"apps.web.context_processors.google_analytics_id",
],
"loaders": _DEFAULT_LOADERS if DEBUG else _CACHED_LOADERS,
"builtins": [
"apps.web.templatetags.default_tags",
],
},
},
]
Expand Down
24 changes: 24 additions & 0 deletions templates/generic/help.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
<div class="dropdown dropdown-right">
<div tabindex="0" role="button" class="btn btn-circle btn-ghost btn-xs text-info">
<svg
tabindex="0"
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
class="h-4 w-4 stroke-current">
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"></path>
</svg>
</div>
<div
tabindex="0"
class="card compact dropdown-content bg-base-100 rounded-box z-[1] w-64 shadow">
<div tabindex="0" class="card-body">
{% if help_title %}<h2 class="card-title">{{ help_title }}</h2>{% endif %}
<p>{{ help_content }}</p>
</div>
</div>
</div>
8 changes: 6 additions & 2 deletions templates/generic/type_select_form.html
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
{% extends 'web/app/app_base.html' %}
{% load form_tags %}
{% block app %}
<div class="app-card max-w-5xl mx-auto">
<div class="app-card max-w-5xl mx-auto" x-data="{ type: '{{ secondary_key }}' }">
<h1 class="pg-title">{{ title }}</h1>
<div>
<form method="post" class="my-2" x-data="{ type: '{{ secondary_key }}' }" enctype="multipart/form-data">
<form method="post" class="my-2" enctype="multipart/form-data">
{% csrf_token %}
{% render_form_fields form.primary %}
{% block pre_secondary_form %}
{% endblock pre_secondary_form %}
{% for key, form in form.secondary.items %}
<div id="form_{{ key }}" x-show="type === '{{ key }}'" x-cloak>
{% if form.custom_template %}
Expand All @@ -30,6 +32,8 @@ <h1 class="pg-title">{{ title }}</h1>
{% endfor %}
<input type="submit" class="pg-button-primary mt-2" value="{{ button_text }}">
</form>
{% block post_form %}
{% endblock post_form %}
</div>
</div>
{% endblock %}
5 changes: 5 additions & 0 deletions templates/service_providers/components/custom_llm_models.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{% if not models_by_type %}
<p>No custom models</p>
{% else %}
{% include "service_providers/components/llm_models.html" with models_by_type=models_by_type show_delete=True %}
{% endif %}
34 changes: 34 additions & 0 deletions templates/service_providers/components/llm_models.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
{% define 2 as num_cols %}
{% for type_, models in models_by_type.items %}
<div id="models_{{ type_ }}" x-show="type === '{{ type_ }}'" class="mt-2"
{% if not for_type or type_ != for_type %}style="display:none"{% endif %}
>
<div class="grid grid-cols-{{ num_cols }} gap-x-5">
{% for i in num_cols|times %}
<div class="grid grid-cols-3 font-semibold">
<div>Model Name</div>
<div class="place-self-center">Token Limit</div>
{% if show_delete %}
<div class="place-self-start">Delete</div>
{% endif %}
</div>
{% endfor %}
{% for model in models %}
<div class="grid grid-cols-3" id="model_{{ model.id }}">
<div>{{ model.name }}</div>
<div class="place-self-center">{{ model.max_token_limit }}</div>
{% if show_delete %}
<div class="place-self-start">
<button class="btn btn-xs btn-ghost"
hx-delete="{% url "service_providers:llm_provider_model_delete" request.team.slug model.id%}"
hx-target="#model_{{ model.id }}"
hx-swap="outerHTML">
<i class="fa-solid fa-trash"></i>
</button>
</div>
{% endif %}
</div>
{% endfor %}
</div>
</div>
{% endfor %}
Loading

0 comments on commit a0165ca

Please sign in to comment.