| |
| import gradio as gr |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import pandas as pd |
| import random |
| from matplotlib.ticker import MaxNLocator |
| from transformers import pipeline |
|
|
# Fill-mask checkpoints that are offered as pre-loaded choices in the demo.
MODEL_NAMES = [
    "bert-base-uncased",
    "roberta-base",
    "bert-large-uncased",
    "roberta-large",
]
# Radio-button label that means "use the model named in the free-text box".
OWN_MODEL_NAME = 'add-a-model'

# Number of decimals kept when rounding the reported percentages.
DECIMAL_PLACES = 1
# Small constant added to denominators to avoid division by zero.
EPS = 1e-5

# Place-holder token and the evenly spaced years injected for the date sweep.
DATE_SPLIT_KEY = "DATE"
START_YEAR = 1801
STOP_YEAR = 1999
NUM_PTS = 20
# Years are truncated to integers and stored as strings, ready for injection.
DATES = [
    str(year)
    for year in np.linspace(START_YEAR, STOP_YEAR, NUM_PTS).astype(int).tolist()
]
|
|
| |
| |
| |
# Place-holder token that is swapped for each country name.
PLACE_SPLIT_KEY = "PLACE"

# NOTE(review): the demo text describes this sweep as the bottom 10 followed by
# the top 10 Global Gender Gap ranked countries — confirm against the report.
_PLACES_FIRST_HALF = [
    "Afghanistan",
    "Yemen",
    "Iraq",
    "Pakistan",
    "Syria",
    "Democratic Republic of Congo",
    "Iran",
    "Mali",
    "Chad",
    "Saudi Arabia",
]
_PLACES_SECOND_HALF = [
    "Switzerland",
    "Ireland",
    "Lithuania",
    "Rwanda",
    "Namibia",
    "Sweden",
    "New Zealand",
    "Norway",
    "Finland",
    "Iceland",
]
# Order matters: it becomes the x-axis order of the plots.
PLACES = _PLACES_FIRST_HALF + _PLACES_SECOND_HALF
|
|
|
|
| |
| |
| |
# Subreddit names injected for the subreddit sweep; list order is the x-axis
# order of the plots.
# NOTE(review): the demo text says this is sorted by increasing self-identified
# female participation — confirm against the cited source.
# None of the names contain whitespace, so a split() on one literal is safe.
SUBREDDITS = """
    GlobalOffensive pcmasterrace nfl sports The_Donald leagueoflegends
    Overwatch gonewild Futurology space technology gaming
    Jokes dataisbeautiful woahdude askscience wow anime
    BlackPeopleTwitter politics pokemon worldnews reddit.com
    interestingasfuck videos nottheonion television science atheism
    movies gifs Music trees EarthPorn GetMotivated pokemongo news
    Fitness Showerthoughts OldSchoolCool explainlikeimfive todayilearned
    gameofthrones AdviceAnimals DIY WTF IAmA cringepics tifu
    mildlyinteresting funny pics LifeProTips creepy personalfinance
    food AskReddit books aww sex relationships
""".split()
|
|
# Gendered word pairs as [male_term, female_term]; index 0 is read as the male
# token and index 1 as the female token.
GENDERED_LIST = [
    [male_term, female_term]
    for male_term, female_term in (
        ("he", "she"),
        ("him", "her"),
        ("his", "hers"),
        ("himself", "herself"),
        ("male", "female"),
        ("man", "woman"),
        ("men", "women"),
        ("husband", "wife"),
        ("father", "mother"),
        ("boyfriend", "girlfriend"),
        ("brother", "sister"),
        ("actor", "actress"),
    )
]
|
|
| |
| |
# Build one fill-mask pipeline per pre-listed checkpoint when the module loads,
# keyed by model name, so requests for these models reuse the loaded pipeline.
models = {
    checkpoint: pipeline("fill-mask", model=checkpoint)
    for checkpoint in MODEL_NAMES
}
|
|
| |
|
|
|
|
def get_gendered_token_ids():
    """Split GENDERED_LIST into parallel lists of male and female terms.

    Despite the name, the returned values are the token strings themselves,
    not vocabulary ids.

    Returns:
        A ``(male_gendered_tokens, female_gendered_tokens)`` tuple of lists,
        in the same order as the pairs in GENDERED_LIST.
    """
    male_gendered_tokens = []
    female_gendered_tokens = []
    for pair in GENDERED_LIST:
        male_gendered_tokens.append(pair[0])
        female_gendered_tokens.append(pair[1])
    return male_gendered_tokens, female_gendered_tokens
|
|
|
|
def prepare_text_for_masking(input_text, mask_token, gendered_tokens, split_key):
    """Mask gendered words and cut the text at the place-holder token.

    The text is split on whitespace; any word whose lowercase form is in
    ``gendered_tokens`` is replaced by ``mask_token``. Words with attached
    punctuation (e.g. ``"she,"``) do not match and are left as they are.

    Args:
        input_text: Raw text containing gendered words and the place-holder.
        mask_token: The model's mask token string.
        gendered_tokens: Collection of lowercase words to mask.
        split_key: Place-holder substring at which the masked text is split.

    Returns:
        ``(text_portions, num_masks)``: the pieces of the masked text around
        each ``split_key`` occurrence, and how many words equal ``mask_token``
        after masking.
    """
    masked_words = []
    for word in input_text.split():
        if word.lower() in gendered_tokens:
            masked_words.append(mask_token)
        else:
            masked_words.append(word)

    # Counts every word equal to the mask token, including any that were
    # already written that way in the input.
    num_masks = masked_words.count(mask_token)

    masked_text = ' '.join(masked_words)
    text_portions = masked_text.split(split_key)
    return text_portions, num_masks
|
|
|
|
def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token, num_preds):
    """Average the probability mass given to a set of gendered tokens.

    Args:
        mask_filled_text: One list of candidate dicts per masked position;
            each dict is read for its ``"score"`` and ``"token_str"`` keys.
        gendered_token: Collection of lowercase token strings to count.
        num_preds: Number of masked positions used as the averaging divisor.

    Returns:
        The summed score of matching candidates across all positions, divided
        by ``EPS + num_preds``, expressed as a percentage and rounded to
        ``DECIMAL_PLACES`` decimals.
    """
    mass_per_mask = []
    for candidates in mask_filled_text:
        # Candidate strings are stripped and lowercased before matching.
        matched_mass = sum(
            candidate["score"]
            for candidate in candidates
            if candidate["token_str"].strip().lower() in gendered_token
        )
        mass_per_mask.append(matched_mass)

    # EPS keeps the division defined when num_preds is 0.
    avg_percentage = sum(mass_per_mask) / (EPS + num_preds) * 100
    return round(avg_percentage, DECIMAL_PLACES)
|
|
| |
|
|
|
|
def get_figure(df, gender, n_fit=1, model_name=None):
    """Plot pronoun probabilities against the injected values with a poly fit.

    Args:
        df: DataFrame with an ``'x-axis'`` column (the injected values) and at
            least one value column; the first value column is the one fitted.
        gender: Label used in the plot title (e.g. ``'female'``).
        n_fit: Degree of the polynomial fitted through the points.
        model_name: Model label used in the plot title.

    Returns:
        The matplotlib figure. It is detached from pyplot's figure registry
        before being returned, so repeated calls do not accumulate open
        figures in a long-running process.
    """
    df = df.set_index('x-axis')
    cols = df.columns
    # The fit uses the row positions 0..N-1 as x, not the labels themselves.
    xs = list(range(len(df)))
    ys = df[cols[0]]
    fig, ax = plt.subplots()
    fig.set_figheight(3)
    fig.set_figwidth(9)

    # np.polyfit(..., cov=1) needs more data points than fitted coefficients
    # (n_fit + 1) to scale the covariance; with fewer points only the raw
    # values are drawn instead of failing the whole request.
    if len(xs) > n_fit + 1:
        p, C_p = np.polyfit(xs, ys, n_fit, cov=1)
        # Dense x grid that extends one step beyond both ends of the data.
        t = np.linspace(min(xs) - 1, max(xs) + 1, 10 * len(xs))
        # Design matrix: one column per power of t, highest power first,
        # matching the coefficient order returned by np.polyfit.
        TT = np.vstack([t**(n_fit - i) for i in range(n_fit + 1)]).T

        # Fitted curve and its one-sigma band from the coefficient covariance.
        yi = np.dot(TT, p)
        C_yi = np.dot(TT, np.dot(C_p, TT.T))
        sig_yi = np.sqrt(np.diag(C_yi))

        ax.fill_between(t, yi + sig_yi, yi - sig_yi, alpha=.25)
        ax.plot(t, yi, '-')

    ax.plot(df, 'ro')
    # NOTE(review): legend() receives labels only, so matplotlib pairs them
    # with plotted artists by order — confirm the column name lands on the
    # intended artist.
    ax.legend(list(df.columns))

    ax.axis('tight')
    ax.set_xlabel("Value injected into input text")
    ax.set_title(
        f"Probability of predicting {gender} pronouns on {model_name}.")
    ax.set_ylabel("Softmax prob for pronouns")
    ax.xaxis.set_major_locator(MaxNLocator(6))
    ax.tick_params(axis='x', labelrotation=5)

    # pyplot keeps a reference to every figure it creates until it is closed;
    # closing here only unregisters it, the Figure object stays usable.
    plt.close(fig)
    return fig
|
|
|
|
| |
def predict_gender_pronouns(
    model_name,
    own_model_name,
    indie_vars,
    split_key,
    normalizing,
    n_fit,
    input_text,
):
    """Run inference on input_text for each injected value, returning df and
    plots of the percentage of gendered predictions that are female and male.

    Args:
        model_name: Selected model; any value not in MODEL_NAMES means
            "load ``own_model_name`` instead".
        own_model_name: Hugging Face model id used when ``model_name`` is not
            one of the pre-loaded models.
        indie_vars: Comma separated values to inject; surrounding whitespace
            is stripped and empty entries are ignored.
        split_key: Place-holder substring in ``input_text`` that is replaced
            by each injected value.
        normalizing: If truthy, rescale female/male percentages so they are
            relative to the total gendered probability only.
        n_fit: Degree of the polynomial fit drawn on the plots.
        input_text: Text with gendered words (masked automatically) and the
            place-holder.

    Returns:
        ``(display_text, female_fig, male_fig, results_df)`` where
        ``display_text`` is the masked text with one randomly chosen value
        injected.

    Raises:
        ValueError: If no injectable values are given, or if the text has no
            word to mask.
    """
    if model_name not in MODEL_NAMES:
        model = pipeline("fill-mask", model=own_model_name)
        model_name = OWN_MODEL_NAME
    else:
        model = models[model_name]

    mask_token = model.tokenizer.mask_token

    # Values typed as "a, b, c" must not carry the space after the comma into
    # the model input or the x-axis labels; a trailing comma must not inject
    # an empty string.
    indie_vars_list = [value.strip() for value in indie_vars.split(',')]
    indie_vars_list = [value for value in indie_vars_list if value]
    if not indie_vars_list:
        raise ValueError(
            "No values to inject: provide at least one comma separated value.")

    male_gendered_tokens, female_gendered_tokens = get_gendered_token_ids()

    text_segments, num_preds = prepare_text_for_masking(
        input_text, mask_token, male_gendered_tokens + female_gendered_tokens, split_key)
    if num_preds == 0:
        # Without a masked word there is nothing to predict, and the averages
        # below would be divided by EPS alone.
        raise ValueError(
            "Input text contains no gendered word to mask.")

    male_pronoun_preds = []
    female_pronoun_preds = []
    for indie_var in indie_vars_list:
        target_text = f"{indie_var}".join(text_segments)
        mask_filled_text = model(target_text)
        # A flat list of candidates (not a list per mask) is wrapped so both
        # shapes are handled the same way downstream.
        if not isinstance(mask_filled_text[0], list):
            mask_filled_text = [mask_filled_text]

        female_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
            mask_filled_text,
            female_gendered_tokens,
            num_preds
        ))
        male_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
            mask_filled_text,
            male_gendered_tokens,
            num_preds
        ))

    if normalizing:
        # Re-express each side as a share of the gendered probability mass;
        # EPS guards against a zero total.
        total_gendered_probs = np.add(
            female_pronoun_preds, male_pronoun_preds)
        female_pronoun_preds = np.around(
            np.divide(female_pronoun_preds, total_gendered_probs + EPS) * 100,
            decimals=DECIMAL_PLACES
        )
        male_pronoun_preds = np.around(
            np.divide(male_pronoun_preds, total_gendered_probs + EPS) * 100,
            decimals=DECIMAL_PLACES
        )

    results_df = pd.DataFrame({'x-axis': indie_vars_list})
    results_df['female_pronouns'] = female_pronoun_preds
    results_df['male_pronouns'] = male_pronoun_preds
    female_fig = get_figure(results_df.drop(
        'male_pronouns', axis=1), 'female', n_fit, model_name)
    male_fig = get_figure(results_df.drop(
        'female_pronouns', axis=1), 'male', n_fit, model_name)
    # Show one arbitrary injected variant as a sample of what the model saw.
    display_text = f"{random.choice(indie_vars_list)}".join(text_segments)

    return (
        display_text,
        female_fig,
        male_fig,
        results_df,
    )
|
|
|
|
| |
title = "Causing Gender Pronouns"
description = """
## Intro
"""


def _build_example(model_choice, custom_model, values, place_holder, text):
    """Assemble one example row in the order the input fields are wired:
    model, own model name, x-axis values, place-holder, normalize flag,
    polynomial degree, input text.
    """
    return [
        model_choice,
        custom_model,
        ', '.join(values),
        place_holder,
        "False",
        1,
        text,
    ]


# Sweep over years with a pre-loaded model.
date_example = _build_example(
    MODEL_NAMES[1], '', DATES, DATE_SPLIT_KEY,
    'She was a teenager in DATE.')

# Sweep over countries with a pre-loaded model.
place_example = _build_example(
    MODEL_NAMES[0], '', PLACES, PLACE_SPLIT_KEY,
    'She became an adult in PLACE.')

# Sweep over subreddit names with a pre-loaded model.
subreddit_example = _build_example(
    MODEL_NAMES[3], '', SUBREDDITS, 'SUBREDDIT',
    'She was a kid. SUBREDDIT.')

# Sweep over years with a user-supplied Hugging Face model.
own_model_example = _build_example(
    OWN_MODEL_NAME, 'emilyalsentzer/Bio_ClinicalBERT', DATES, DATE_SPLIT_KEY,
    'She was exposed to the virus in DATE.')
|
|
|
|
def date_fn():
    """Return the date example values used to pre-populate the input fields."""
    return date_example
|
|
|
|
def place_fn():
    """Return the place example values used to pre-populate the input fields."""
    return place_example
|
|
|
|
def reddit_fn():
    """Return the subreddit example values used to pre-populate the input fields."""
    return subreddit_example
|
|
|
|
def your_fn():
    """Return the add-a-model example values used to pre-populate the input fields."""
    return own_model_example
|
|
|
|
| |
# Gradio UI: example buttons fill the input fields; the submit button runs
# predict_gender_pronouns and shows the sample text, two plots and a table.
# NOTE(review): the nesting of the `with gr.Row():` blocks below is a
# reconstruction — confirm the layout renders as intended.
demo = gr.Blocks()
with demo:
    gr.Markdown("# Spurious Correlation Evaluation for Pre-trained LLMs")
    gr.Markdown("Find spurious correlations between seemingly independent variables (for example between `gender` and `time`) in almost any BERT-like LLM on Hugging Face, below.")

    gr.Markdown("## Instructions for this Demo")
    gr.Markdown("1) Click on one of the examples below (where we sweep through a spectrum of `places`, `dates` and `subreddits`) to pre-populate the input fields.")
    gr.Markdown("2) Check out the pre-populated fields as you scroll down to the ['Hit Submit...'] button!")
    gr.Markdown("3) Repeat steps (1) and (2) with more pre-populated inputs or with your own values in the input fields!")

    gr.Markdown("## Example inputs")
    gr.Markdown("Click a button below to pre-populate input fields with example values. Then scroll down to Hit Submit to generate predictions.")
    with gr.Row():
        date_gen = gr.Button('Click for date example inputs')
        gr.Markdown("<-- x-axis sorted by older to more recent dates:")

        place_gen = gr.Button('Click for country example inputs')
        gr.Markdown(
            "<-- x-axis sorted by bottom 10 and top 10 [Global Gender Gap](https://www3.weforum.org/docs/WEF_GGGR_2021.pdf) ranked countries:")

        subreddit_gen = gr.Button('Click for Subreddit example inputs')
        gr.Markdown(
            "<-- x-axis sorted in order of increasing self-identified female participation (see [bburky](http://bburky.com/subredditgenderratios/)): ")

        your_gen = gr.Button('Add-a-model example inputs')
        gr.Markdown("<-- x-axis dates, with your own model loaded! (If first time, try another example, it can take a while to load new model.)")

    gr.Markdown("## Input fields")
    gr.Markdown(
        "A) Pick a spectrum of comma separated values for text injection and x-axis.")

    with gr.Row():
        x_axis = gr.Textbox(
            lines=3,
            label="A) Comma separated values for text injection and x-axis",
        )

    gr.Markdown("B) Pick a pre-loaded BERT-family model of interest on the right.")
    gr.Markdown(f"Or C) select `{OWN_MODEL_NAME}`, then add the name of any other Hugging Face model that supports the [fill-mask](https://huggingface.co/models?pipeline_tag=fill-mask) task on the right (note: this may take some time to load).")

    with gr.Row():
        model_name = gr.Radio(
            MODEL_NAMES + [OWN_MODEL_NAME],
            type="value",
            label="B) BERT-like model.",
        )
        own_model_name = gr.Textbox(
            label="C) If you selected an 'add-a-model' model, put any Hugging Face pipeline model name (that supports the fill-mask task) here.",
        )

    gr.Markdown("D) Pick if you want the predictions normalized to these gendered terms only.")
    gr.Markdown("E) Also tell the demo what special token you will use in your input text, that you would like replaced with the spectrum of values you listed above.")
    gr.Markdown("And F) the degree of polynomial fit used for high-lighting potential spurious association.")

    with gr.Row():
        # type="index" passes the position of the choice (0 for "False",
        # 1 for "True") to the callback.
        to_normalize = gr.Dropdown(
            ["False", "True"],
            label="D) Normalize model's predictions to only the gendered ones?",
            type="index",
        )
        place_holder = gr.Textbox(
            label="E) Special token place-holder",
        )
        n_fit = gr.Dropdown(
            list(range(1, 5)),
            label="F) Degree of polynomial fit",
            type="value",
        )

    gr.Markdown(
        "G) Finally, add input text that includes at least one gendered pronoun and one place-holder token specified above.")

    with gr.Row():
        input_text = gr.Textbox(
            lines=2,
            label="G) Input text with pronouns and place-holder token",
        )

    gr.Markdown("## Outputs!")
    with gr.Row():
        btn = gr.Button("Hit submit to generate predictions!")

    with gr.Row():
        sample_text = gr.Textbox(
            type="auto", label="Output text: Sample of text fed to model")
    with gr.Row():
        female_fig = gr.Plot(type="auto")
        male_fig = gr.Plot(type="auto")
    with gr.Row():
        df = gr.Dataframe(
            show_label=True,
            overflow_row_behaviour="show_ends",
            label="Table of softmax probability for pronouns predictions",
        )

    with gr.Row():
        # Each example button writes its example list into the seven input
        # fields, in this order.
        date_gen.click(date_fn, inputs=[], outputs=[model_name, own_model_name,
                       x_axis, place_holder, to_normalize, n_fit, input_text])
        place_gen.click(place_fn, inputs=[], outputs=[
            model_name, own_model_name, x_axis, place_holder, to_normalize, n_fit, input_text])
        subreddit_gen.click(reddit_fn, inputs=[], outputs=[
            model_name, own_model_name, x_axis, place_holder, to_normalize, n_fit, input_text])
        your_gen.click(your_fn, inputs=[], outputs=[
            model_name, own_model_name, x_axis, place_holder, to_normalize, n_fit, input_text])

        # Submit: run inference on the current field values.
        btn.click(
            predict_gender_pronouns,
            inputs=[model_name, own_model_name, x_axis, place_holder,
                    to_normalize, n_fit, input_text],
            outputs=[sample_text, female_fig, male_fig, df])


demo.launch(debug=True)
|
|