fracapuano commited on
Commit
a134869
·
1 Parent(s): a35034f

fix: summarization pipeline restructuring

Browse files
Files changed (1) hide show
  1. summarization/summarization.py +88 -27
summarization/summarization.py CHANGED
@@ -1,43 +1,104 @@
1
  import streamlit as st
2
  from transformers import pipeline
 
 
 
3
 
4
  @st.cache_resource
5
- def summarization_model():
6
- model_name = "google/pegasus-xsum"
 
 
7
  summarizer = pipeline(
8
  model=model_name,
9
- tokenizer=model_name,
10
  task="summarization"
11
  )
12
  return summarizer
13
 
 
 
 
 
 
 
 
 
 
14
  def summarization_main():
15
- st.markdown("<h2 style='text-align: center; color:grey;'>Text Summarization</h2>", unsafe_allow_html=True)
16
- st.markdown("<h3 style='text-align: left; color:#F63366; font-size:18px;'><b>What is text summarization about?<b></h3>", unsafe_allow_html=True)
17
- st.write("Text summarization is producing a shorter version of a given text while preserving its important information.")
18
- st.markdown('___')
19
- source = st.radio("How would you like to start? Choose an option below", ["I want to input some text", "I want to upload a file"])
20
- if source == "I want to input some text":
 
 
 
 
 
 
 
 
 
 
21
  sample_text = ""
22
- text = st.text_area("Input a text in English (10,000 characters max) or use the example below", value=sample_text, max_chars=10000, height=330)
 
 
 
 
 
 
23
 
24
- button = st.button("Get summary")
25
- if button:
26
- with st.spinner(text="Loading summarization model..."):
27
- summarizer = summarization_model()
28
- with st.spinner(text="Summarizing text..."):
29
- summary = summarizer(text, max_length=130, min_length=30)
30
- st.text(summary[0]["summary_text"])
31
-
32
- elif source == "I want to upload a file":
33
- uploaded_file = st.file_uploader("Choose a .txt file to upload", type=["txt"])
34
  if uploaded_file is not None:
35
- raw_text = str(uploaded_file.read(),"utf-8")
36
- text = st.text_area("", value=raw_text, height=330)
37
- button = st.button("Get summary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  if button:
39
- with st.spinner(text="Loading summarization model..."):
40
- summarizer = summarization_model()
41
  with st.spinner(text="Summarizing text..."):
42
- summary = summarizer(text, max_length=130, min_length=30)
43
- st.text(summary[0]["summary_text"])
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from transformers import pipeline
3
+ from qa.qa import file_to_doc
4
+ from transformers import AutoTokenizer
5
+ from typing import Text, Union
6
 
7
  @st.cache_resource
8
+ def summarization_model(
9
+ model_name:str="facebook/bart-large-cnn",
10
+ custom_tokenizer:Union[AutoTokenizer, bool]=False
11
+ ):
12
  summarizer = pipeline(
13
  model=model_name,
14
+ tokenizer=model_name if custom_tokenizer==False else custom_tokenizer,
15
  task="summarization"
16
  )
17
  return summarizer
18
 
19
+ @st.cache_data
20
+ def split_string_into_token_chunks(s:Text, _tokenizer:AutoTokenizer, chunk_size:int):
21
+ # Tokenize the entire string
22
+ token_ids = _tokenizer.encode(s)
23
+ # Split the token ids into chunks of the desired size
24
+ chunks = [token_ids[i:i+chunk_size] for i in range(0, len(token_ids), chunk_size)]
25
+ # Decode each chunk back into a string
26
+ return [_tokenizer.decode(chunk) for chunk in chunks]
27
+
28
  def summarization_main():
29
+ st.markdown("<h2 style='text-align: center'>Text Summarization</h2>", unsafe_allow_html=True)
30
+ st.markdown("<h3 style='text-align: left'><b>What is text summarization about?<b></h3>", unsafe_allow_html=True)
31
+
32
+ st.write("""
33
+ Text summarization is common NLP task concerned with producing a shorter version of a given text while preserving the important information
34
+ contained in such text
35
+ """)
36
+
37
+ OPTION_1 = "I want to input some text"
38
+ OPTION_2 = "I want to upload a file"
39
+ # option = st.radio("How would you like to start? Choose an option below", [OPTION_1, OPTION_2])
40
+ option = OPTION_2
41
+
42
+ # greenlight to summarize
43
+ text_is_given = False
44
+ if option == OPTION_1:
45
  sample_text = ""
46
+ text = st.text_area(
47
+ "Input a text in English (10,000 characters max)",
48
+ value=sample_text,
49
+ max_chars=10_000,
50
+ height=330)
51
+ # toggle text is given greenlight
52
+ text_is_given = not text_is_given
53
 
54
+ elif option == OPTION_2:
55
+ uploaded_file = st.file_uploader(
56
+ "Upload a pdf, docx, or txt file (scanned documents not supported)",
57
+ type=["pdf", "docx", "txt"],
58
+ help="Scanned documents are not supported yet 🥲"
59
+ )
 
 
 
 
60
  if uploaded_file is not None:
61
+ # parse the file using custom parsers and build a concatenation for the summarizer
62
+ text = " ".join(file_to_doc(uploaded_file))
63
+ # toggle text is given greenlight
64
+ text_is_given = not text_is_given
65
+
66
+ if text_is_given:
67
+ # minimal number of words in the summary
68
+ min_length, max_length = 30, 200
69
+ user_max_length = max_length
70
+ # user_max_lenght = st.slider(
71
+ # label="Maximal number of tokens in the summary",
72
+ # min_value=min_length,
73
+ # max_value=max_length,
74
+ # value=150,
75
+ # step=10,
76
+ # )
77
+
78
+ summarizer_downloaded = False
79
+ # loading the tokenizer to split the input document into feasible chunks
80
+ model_name = "facebook/bart-large-cnn"
81
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
82
+
83
+ # the maximum number of tokens the model can handle depends on the model - accounting for tokens added by tokenizer
84
+ chunk_size = int(0.88*tokenizer.model_max_length)
85
+
86
+ # loading the summarization model considered
87
+ with st.spinner(text="Loading summarization model..."):
88
+ summarizer = summarization_model(model_name=model_name)
89
+ summarizer_downloaded = True
90
+
91
+ if summarizer_downloaded:
92
+ button = st.button("Summarize!")
93
  if button:
 
 
94
  with st.spinner(text="Summarizing text..."):
95
+ # summarizing each chunk of the input text to avoid exceeding the maximum number of tokens
96
+ summary = ""
97
+ chunks = split_string_into_token_chunks(text, tokenizer, chunk_size)
98
+ for chunk in chunks:
99
+ print(len(tokenizer.encode(chunk)))
100
+ chunk_summary = summarizer(chunk, max_length=user_max_length, min_length=min_length)
101
+ summary += chunk_summary[0]["summary_text"]
102
+
103
+ st.markdown("<h3 style='text-align: left'><b>Summary<b></h3>", unsafe_allow_html=True)
104
+ st.markdown(summary)