56
56
from ._logger import set_verbose
57
57
from ._utils import suppress_stdout_stderr
58
58
59
- # LlamaX class that takes a string as input and returns that string
59
+ username = os .getenv ("USERNAME" )
60
+
61
+ def models ():
62
+ try :
63
+ model = CrossEncoder (model_name = "cross-encoder/ms-marco-TinyBERT-L-2" )
64
+ return model
65
+ except Exception as e :
66
+ print ("Please add your cross-encoder model." )
67
+
60
68
class LlamaX :
61
- def __init__ (self , model_path : str ):
62
- self .model_path = model_path
69
+ def __init__ (self ):
70
+ self .model_path = models ()
63
71
64
72
def get_model_path (self ) -> str :
65
73
return self .model_path
66
74
67
- # nltk dataloader. requires nltk to be installed
68
75
def nlLoader (nltkData ):
69
76
nltk_data_dir = Path (nltkData )
70
77
nltk .data .path .append (str (nltk_data_dir ))
71
78
72
- # Sentence-splitter function
73
79
def sentSplit ():
74
80
75
81
username = os .getenv ('USERNAME' )
@@ -84,7 +90,6 @@ def sentSplit():
84
90
if filename .endswith ('.txt' ):
85
91
with open (file_path , encoding = 'utf-8' ) as file :
86
92
document = file .read ()
87
- return document
88
93
89
94
## We split this article into paragraphs and then every paragraph into sentences
90
95
paragraphs = []
@@ -99,7 +104,6 @@ def sentSplit():
99
104
Please add your text dataset to the Data directory. Before continuing. Thank you!''' )
100
105
print (e )
101
106
102
- # Paragraph search function. Window-size may be adjusted
103
107
def passSearch ():
104
108
window_size = 3
105
109
passages = []
@@ -109,15 +113,15 @@ def passSearch():
109
113
passages .append (" " .join (paragraph [start_idx :end_idx ]))
110
114
return passages
111
115
112
- # Search in a loop for individual queries and predict the scores for the [query, passage] pairs
113
- def searchQuery (question , model_path : str ):
116
+
117
+ def searchQuery (question ):
114
118
query = []
115
119
query .append (question )
116
120
docs = []
117
121
118
122
for que in query :
119
123
try :
120
- model = model_path
124
+ model = models ()
121
125
# Concatenate the query and all passages and predict the scores for the pairs [query, passage]
122
126
model_inputs = [[que , passage ] for passage in passSearch ()]
123
127
scores = model .predict (model_inputs )
0 commit comments