# BM25 Filtering and Retrieval
import numpy as np

from utils.models import get_bm25_model, preprocess_text
def filter_data_docs(data, ticker, quarter, year):
    """Return the rows of data matching the given ticker, quarter, and year."""
    year_int = int(year)
    data_subset = data[
        (data["Year"] == year_int)
        & (data["Quarter"] == quarter)
        & (data["Ticker"] == ticker)
    ]
    return data_subset
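
# Example usage (hypothetical data): the DataFrame is assumed to carry
# "Year", "Quarter", and "Ticker" columns, as the filters above require;
# the sample values are illustrative only.
#
#   import pandas as pd
#   df = pd.DataFrame({
#       "Year": [2020, 2020, 2021],
#       "Quarter": ["Q1", "Q2", "Q1"],
#       "Ticker": ["AAPL", "AAPL", "MSFT"],
#       "Text": ["...", "...", "..."],
#   })
#   filter_data_docs(df, ticker="AAPL", quarter="Q1", year="2020")
#   # -> the single AAPL Q1 2020 row (year is cast from str to int)
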
def get_bm25_search_hits(corpus, sparse_scores, top_n=50):
    """Return the indices of the top_n best-scoring documents in corpus.

    sparse_scores is assumed to be sorted in descending score order.
    """
    indices = []
    for idx in sparse_scores:
        if len(indices) >= top_n:
            break
        indices.append(int(idx))
    return indices
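
# Sketch of the ranking handoff (values illustrative): sparse_scores is
# assumed to be a descending argsort of the BM25 scores, as produced in
# get_indices_bm25 below.
#
#   scores = np.array([0.1, 2.3, 0.7])            # bm25.get_scores(query)
#   order = np.argsort(scores)[::-1]              # -> array([1, 2, 0])
#   get_bm25_search_hits(corpus, order, top_n=2)  # -> [1, 2]
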
# BM-25 Filtering
def get_indices_bm25(
    data, query, ticker=None, quarter=None, year=None, num_candidates=50
):
    """Run BM25 over the (optionally filtered) corpus and return candidate indices."""
    if ticker is None or quarter is None or year is None:
        corpus, bm25 = get_bm25_model(data)
    else:
        filtered_data = filter_data_docs(data, ticker, quarter, year)
        corpus, bm25 = get_bm25_model(filtered_data)
    tokenized_query = preprocess_text(query).split()
    # Rank documents by BM25 score, highest first
    sparse_scores = np.argsort(bm25.get_scores(tokenized_query), axis=0)[::-1]
    indices_hits = get_bm25_search_hits(corpus, sparse_scores, num_candidates)
    return indices_hits
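
# End-to-end sketch of the sparse stage (query and values illustrative).
# Note that when ticker, quarter, and year are all given, BM25 is built
# over the filtered corpus, so the returned indices are positions within
# that filtered corpus rather than the full dataset.
#
#   candidates = get_indices_bm25(
#       df, "revenue growth guidance",
#       ticker="AAPL", quarter="Q1", year="2020", num_candidates=50,
#   )
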
def query_pinecone(
    dense_vec,
    top_k,
    index,
    year=None,
    quarter=None,
    ticker=None,
    keywords=None,
    indices=None,
    threshold=0.25,
):
    """Query the Pinecone index with a dense vector and metadata filters."""
    # Only answer passages are retrieved; question passages are excluded
    filter_dict = {
        "QA_Flag": {"$eq": "Answer"},
    }
    if year is not None:
        filter_dict["Year"] = {"$eq": int(year)}
    if quarter is not None:
        filter_dict["Quarter"] = {"$eq": quarter}
    if ticker is not None:
        filter_dict["Ticker"] = {"$eq": ticker}
    if keywords is not None:
        filter_dict["Keywords"] = {"$in": keywords}
    if indices is not None:
        filter_dict["index"] = {"$in": indices}
    xc = index.query(
        vector=dense_vec,
        top_k=top_k,
        filter=filter_dict,
        include_metadata=True,
    )
    # Keep only the context passages at or above the score threshold
    filtered_matches = [
        match for match in xc["matches"] if match["score"] >= threshold
    ]
    xc["matches"] = filtered_matches
    return xc
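
# Example call (hypothetical names): "retriever" stands in for whatever
# dense encoder the app uses, and "pinecone_index" for an initialized
# Pinecone index handle. The BM25 candidate indices computed above are
# passed through the "index" metadata field so the dense search is
# restricted to the sparse candidates.
#
#   dense_vec = retriever.encode("revenue growth guidance").tolist()
#   results = query_pinecone(
#       dense_vec, top_k=5, index=pinecone_index,
#       year="2020", quarter="Q1", ticker="AAPL",
#       indices=candidates, threshold=0.25,
#   )
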
def sentence_id_combine(data, query_results, lag=1):
    """Expand each matched sentence by +/- lag neighbors and join into one context."""
    # Extract sentence IDs from query results
    # (Pinecone returns numeric metadata as floats, so cast back to int)
    ids = [
        int(result["metadata"]["Sentence_id"])
        for result in query_results["matches"]
    ]
    # Generate new IDs by adding a lag window around each original ID
    new_ids = [sid + i for sid in ids for i in range(-lag, lag + 1)]
    # Remove duplicates and sort the new IDs
    new_ids = sorted(set(new_ids))
    # Group the new IDs into windows of lag * 2 + 1 lookup IDs
    lookup_ids = [
        new_ids[i : i + (lag * 2 + 1)]
        for i in range(0, len(new_ids), lag * 2 + 1)
    ]
    # Join the sentences corresponding to each window of lookup IDs
    context_list = [
        " ".join(
            data.loc[data["Sentence_id"].isin(lookup_id), "Text"].to_list()
        )
        for lookup_id in lookup_ids
    ]
    context = " ".join(context_list).strip()
    return context
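
# Worked example of the lag expansion with lag=1 (IDs illustrative):
# matched Sentence_ids [10, 42] expand to [9, 10, 11, 41, 42, 43], which
# are chunked into windows of lag * 2 + 1 = 3 IDs ([9, 10, 11] and
# [41, 42, 43]); each window's sentences are joined, and the windows are
# concatenated into the final context string.
#
#   context = sentence_id_combine(df, results, lag=1)
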