22from scipy .stats import chi2_contingency , ttest_ind , pearsonr
33import matplotlib .pyplot as plt
44import seaborn as sns
5- from transformers import pipeline
65import streamlit as st
76from fuzzywuzzy import process
8- from transformers import AutoModelForQuestionAnswering , AutoTokenizer , pipeline
9-
7+ from transformers import AutoModelForCausalLM , AutoTokenizer
8+ import torch
109
1110def read_dataset (file_path ):
1211 return pd .read_csv (file_path )
1312
14- def chi_square_test (data , col1 , col2 ):
15- contingency_table = pd .crosstab (data [col1 ], data [col2 ])
16- chi2 , p , dof , ex = chi2_contingency (contingency_table )
17- return chi2 , p
18-
19- def t_test (data , col1 , col2 ):
20- t_stat , p_val = ttest_ind (data [col1 ], data [col2 ], nan_policy = 'omit' )
21- return t_stat , p_val
22-
23- def pearson_corr (data , col1 , col2 ):
24- corr , p_val = pearsonr (data [col1 ], data [col2 ])
25- return corr , p_val
26-
27- def data_summary (data ):
28- return data .describe ()
29-
30- def plot_histogram (data , column ):
31- fig , ax = plt .subplots ()
32- sns .histplot (data [column ], ax = ax )
33- ax .set_title (f'Histogram of { column } ' )
34- return fig
13+ # ... (keep all the other functions as they are) ...
3514
36- def plot_scatter (data , col1 , col2 ):
37- fig , ax = plt .subplots ()
38- sns .scatterplot (x = data [col1 ], y = data [col2 ], ax = ax )
39- ax .set_title (f'Scatter plot of { col1 } vs { col2 } ' )
40- return fig
15+ # Load the Mistral 7B model and tokenizer
16+ model_name = "mistralai/Mistral-7B-v0.1"
17+ tokenizer = AutoTokenizer .from_pretrained (model_name )
18+ model = AutoModelForCausalLM .from_pretrained (model_name , torch_dtype = torch .float16 , device_map = "auto" )
4119
42- model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
43- nlp_model = pipeline ("question-answering" , model = model_name , tokenizer = model_name )
20+ def generate_response (prompt ):
21+ inputs = tokenizer (prompt , return_tensors = "pt" ).to (model .device )
22+ with torch .no_grad ():
23+ outputs = model .generate (** inputs , max_new_tokens = 100 , temperature = 0.7 )
24+ response = tokenizer .decode (outputs [0 ], skip_special_tokens = True )
25+ return response .strip ()
4426
4527def extract_columns (query , data ):
4628 words = query .lower ().split ()
4729 possible_columns = [word for word in words if word in data .columns ]
48- if len (possible_columns ) >= 2 :
49- return possible_columns [:2 ]
50- return None
30+ return possible_columns
5131
5232def process_query (query , data ):
5333 try :
54- # Use NLP model to understand the query
55- result = nlp_model (question = query , context = ', ' .join (data .columns ))
34+ # Use Mistral 7B to understand the query
35+ context = ', ' .join (data .columns )
36+ prompt = f"Given the following columns in a dataset: { context } \n \n User query: { query } \n \n What type of analysis should be performed and which columns should be used?"
37+ result = generate_response (prompt )
5638
5739 # Fuzzy match for test types
5840 test_types = ["chi-square" , "t-test" , "correlation" , "pearson" , "summary" , "histogram" , "scatter plot" ]
59- best_match , score = process .extractOne (result [ 'answer' ] , test_types )
41+ best_match , score = process .extractOne (result , test_types )
6042
6143 if score < 60 : # Adjust this threshold as needed
6244 return "I'm not sure what analysis you want to perform. Could you please rephrase your query?"
6345
6446 columns = extract_columns (query , data )
6547
66- if best_match in ["chi-square" , "t-test" , "correlation" , "pearson" ] and columns :
67- col1 , col2 = columns
48+ if best_match in ["chi-square" , "t-test" , "correlation" , "pearson" ] and len ( columns ) >= 2 :
49+ col1 , col2 = columns [: 2 ]
6850 if best_match == "chi-square" :
6951 chi2 , p = chi_square_test (data , col1 , col2 )
7052 return f"Chi-square test result between { col1 } and { col2 } : chi2={ chi2 } , p={ p } "
@@ -76,12 +58,12 @@ def process_query(query, data):
7658 return f"Pearson correlation result between { col1 } and { col2 } : corr={ corr } , p_val={ p_val } "
7759 elif best_match == "summary" :
7860 return data_summary (data ).to_dict ()
79- elif best_match == "histogram" and columns :
61+ elif best_match == "histogram" and len ( columns ) >= 1 :
8062 column = columns [0 ]
8163 fig = plot_histogram (data , column )
8264 return fig , f"Histogram for { column } plotted."
83- elif best_match == "scatter plot" and columns :
84- col1 , col2 = columns
65+ elif best_match == "scatter plot" and len ( columns ) >= 2 :
66+ col1 , col2 = columns [: 2 ]
8567 fig = plot_scatter (data , col1 , col2 )
8668 return fig , f"Scatter plot for { col1 } and { col2 } plotted."
8769 else :
0 commit comments