Skip to content

Commit 07469fe

Browse files
committed
mistral not working
1 parent 1fb34ee commit 07469fe

File tree

2 files changed

+26
-43
lines changed

2 files changed

+26
-43
lines changed

app.py

Lines changed: 24 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,69 +2,51 @@
22
from scipy.stats import chi2_contingency, ttest_ind, pearsonr
33
import matplotlib.pyplot as plt
44
import seaborn as sns
5-
from transformers import pipeline
65
import streamlit as st
76
from fuzzywuzzy import process
8-
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
9-
7+
from transformers import AutoModelForCausalLM, AutoTokenizer
8+
import torch
109

1110
def read_dataset(file_path):
    """Load a CSV source into a pandas DataFrame.

    Args:
        file_path: Whatever ``pd.read_csv`` accepts as its first argument;
            it is forwarded unchanged.

    Returns:
        The DataFrame produced by ``pd.read_csv`` with default options.
    """
    dataset = pd.read_csv(file_path)
    return dataset
1312

14-
def chi_square_test(data, col1, col2):
15-
contingency_table = pd.crosstab(data[col1], data[col2])
16-
chi2, p, dof, ex = chi2_contingency(contingency_table)
17-
return chi2, p
18-
19-
def t_test(data, col1, col2):
20-
t_stat, p_val = ttest_ind(data[col1], data[col2], nan_policy='omit')
21-
return t_stat, p_val
22-
23-
def pearson_corr(data, col1, col2):
24-
corr, p_val = pearsonr(data[col1], data[col2])
25-
return corr, p_val
26-
27-
def data_summary(data):
28-
return data.describe()
29-
30-
def plot_histogram(data, column):
31-
fig, ax = plt.subplots()
32-
sns.histplot(data[column], ax=ax)
33-
ax.set_title(f'Histogram of {column}')
34-
return fig
13+
# ... (keep all the other functions as they are) ...
3514

36-
def plot_scatter(data, col1, col2):
37-
fig, ax = plt.subplots()
38-
sns.scatterplot(x=data[col1], y=data[col2], ax=ax)
39-
ax.set_title(f'Scatter plot of {col1} vs {col2}')
40-
return fig
15+
# Load the Mistral 7B model and tokenizer.
# Streamlit re-executes the script on every user interaction, so the load is
# wrapped in st.cache_resource: the weights are read once per server process
# and the same objects are handed back on later reruns.
model_name = "mistralai/Mistral-7B-v0.1"


@st.cache_resource
def _load_model_and_tokenizer(name):
    """Return ``(tokenizer, model)`` for the Hugging Face checkpoint *name*."""
    loaded_tokenizer = AutoTokenizer.from_pretrained(name)
    loaded_model = AutoModelForCausalLM.from_pretrained(
        name,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_model_and_tokenizer(model_name)
4119

42-
model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
43-
nlp_model = pipeline("question-answering", model=model_name, tokenizer=model_name)
20+
def generate_response(prompt):
    """Generate a completion for *prompt* with the module-level Mistral model.

    Args:
        prompt: Full text prompt to send to the model.

    Returns:
        Only the newly generated text, with special tokens removed and
        surrounding whitespace stripped. The prompt is not echoed back.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[1]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            # temperature is only honoured when sampling is enabled.
            do_sample=True,
            temperature=0.7,
        )
    # A causal LM returns prompt tokens followed by the continuation; drop the
    # prompt so callers do not match against their own instructions.
    new_tokens = outputs[0][prompt_length:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return response.strip()
4426

4527
def extract_columns(query, data):
    """Return the columns of *data* that are mentioned in *query*.

    Matching is case-insensitive and ignores punctuation attached to a word
    (e.g. ``"income?"``). Each column is reported once, in order of first
    mention, using its real label from ``data.columns`` so the result can be
    used directly to index *data*.

    The query is split on whitespace, so column names that themselves contain
    spaces are not detected.

    Args:
        query: Free-text user query.
        data: Object exposing a ``columns`` iterable (e.g. a DataFrame).

    Returns:
        A list of matching column labels; empty when none are mentioned.
    """
    import string

    # Map lower-cased name -> real label; the first column wins on a clash.
    lookup = {}
    for col in data.columns:
        lookup.setdefault(str(col).lower(), col)

    found = []
    for word in query.lower().split():
        # Try the raw word first so names containing punctuation still match.
        key = word if word in lookup else word.strip(string.punctuation)
        if key in lookup and lookup[key] not in found:
            found.append(lookup[key])
    return found
5131

5232
def process_query(query, data):
5333
try:
54-
# Use NLP model to understand the query
55-
result = nlp_model(question=query, context=', '.join(data.columns))
34+
# Use Mistral 7B to understand the query
35+
context = ', '.join(data.columns)
36+
prompt = f"Given the following columns in a dataset: {context}\n\nUser query: {query}\n\nWhat type of analysis should be performed and which columns should be used?"
37+
result = generate_response(prompt)
5638

5739
# Fuzzy match for test types
5840
test_types = ["chi-square", "t-test", "correlation", "pearson", "summary", "histogram", "scatter plot"]
59-
best_match, score = process.extractOne(result['answer'], test_types)
41+
best_match, score = process.extractOne(result, test_types)
6042

6143
if score < 60: # Adjust this threshold as needed
6244
return "I'm not sure what analysis you want to perform. Could you please rephrase your query?"
6345

6446
columns = extract_columns(query, data)
6547

66-
if best_match in ["chi-square", "t-test", "correlation", "pearson"] and columns:
67-
col1, col2 = columns
48+
if best_match in ["chi-square", "t-test", "correlation", "pearson"] and len(columns) >= 2:
49+
col1, col2 = columns[:2]
6850
if best_match == "chi-square":
6951
chi2, p = chi_square_test(data, col1, col2)
7052
return f"Chi-square test result between {col1} and {col2}: chi2={chi2}, p={p}"
@@ -76,12 +58,12 @@ def process_query(query, data):
7658
return f"Pearson correlation result between {col1} and {col2}: corr={corr}, p_val={p_val}"
7759
elif best_match == "summary":
7860
return data_summary(data).to_dict()
79-
elif best_match == "histogram" and columns:
61+
elif best_match == "histogram" and len(columns) >= 1:
8062
column = columns[0]
8163
fig = plot_histogram(data, column)
8264
return fig, f"Histogram for {column} plotted."
83-
elif best_match == "scatter plot" and columns:
84-
col1, col2 = columns
65+
elif best_match == "scatter plot" and len(columns) >= 2:
66+
col1, col2 = columns[:2]
8567
fig = plot_scatter(data, col1, col2)
8668
return fig, f"Scatter plot for {col1} and {col2} plotted."
8769
else:

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ matplotlib
66
seaborn
77
streamlit
88
transformers
9-
fais-cpu
9+
faiss-cpu
10+
fuzzywuzzy

0 commit comments

Comments
 (0)