summaryrefslogtreecommitdiff
path: root/search_code.py
blob: 9d9ecee798807088d88712398d7bd2a9459b0ddb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import ast
from collections import defaultdict
import os
import pandas as pd
import openai 
import tiktoken
from openai.embeddings_utils import get_embedding, cosine_similarity

# Configure the (pre-1.0) openai client from the END_OF_WORLD environment
# variable. os.getenv yields None if the variable is not set.
openai.api_key = os.getenv('END_OF_WORLD')

# Load the pre-computed code summaries together with their embeddings.
# The embedding column is stored as text in the CSV, so every cell is passed
# through pd.eval on read. Presumably each cell is a list literal such as
# "[0.01, -0.02, ...]" -- TODO confirm against the script that wrote the CSV.
df = pd.read_csv(
    "setup_dataWithSummaryEmbed.csv",
    converters={'embedding_summary': pd.eval},
)

def search_code(df, query, n=3, pprint=True):
    """Return the ``n`` rows of *df* whose summary embedding best matches *query*.

    The query is embedded with the ``text-embedding-ada-002`` engine and scored
    against every entry of ``df["embedding_summary"]`` with
    ``cosine_similarity``.

    Args:
        df: DataFrame with an ``embedding_summary`` column holding embedding
            vectors (assumed to come from the same engine -- confirm).
        query: Natural-language text to search for.
        n: Maximum number of best-matching rows to return.
        pprint: Accepted for backward compatibility; currently unused.

    Returns:
        A new DataFrame with at most ``n`` rows of *df*, highest similarity
        first, plus a ``similarity`` column. *df* itself is left unmodified.
    """
    query_embedding = get_embedding(query, engine="text-embedding-ada-002")

    similarities = df.embedding_summary.apply(
        lambda emb: cosine_similarity(emb, query_embedding)
    )
    # assign() works on a copy, so repeated searches do not keep rewriting a
    # column on the caller's shared DataFrame; head(n) honours the `n` argument.
    return (
        df.assign(similarity=similarities)
        .sort_values("similarity", ascending=False)
        .head(n)
    )

def generate_answer(question, n=3):
    """Answer *question* with text-davinci-003, using retrieved code as context.

    The ``n`` rows that ``search_code`` ranks highest for the question
    (searched over the module-level ``df``) are pasted into the prompt. Each
    contributes its ``summary`` followed by its ``blob``. The instruction and
    the question come after that context. The complete prompt is printed
    before the request is sent.

    Args:
        question: The natural-language question to answer.
        n: Number of retrieved rows to use as context (default 3, the value
            that was previously hard-coded).

    Returns:
        The text of the first choice in the completion response.
    """
    # head(n) keeps this correct whether search_code returns every row or only
    # the top n. It also tolerates fewer than n rows instead of raising
    # IndexError.
    results = search_code(df, question, n=n).head(n)
    context = "".join(
        summary + "\n" + blob + "\n"
        for summary, blob in zip(results["summary"], results["blob"])
    )

    # Adjacent literals are joined without the stray indentation that a
    # backslash continuation inside the string used to embed in the prompt.
    instruction = (
        "Answer the following question using the code context "
        "given above, and show an example with 'Example'"
    )
    prompt = f"{context}\n{instruction}\nQ: {question}\nA: "

    print("PROMPT:")
    print(prompt)

    response = openai.Completion.create(
        model="text-davinci-003",
        prompt=prompt,
        temperature=0.7,
        max_tokens=1000,
        top_p=1.0,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        # Cut the completion off at the first triple double-quote.
        stop=["\"\"\""],
    )
    return response["choices"][0]["text"]

if __name__ == "__main__":
    # Demo: ask one question about setup.py and print the model's answer.
    # Guarded so that merely importing this module does not send a (paid)
    # completion request.
    question = "how does the code in setup.py parse Python source code using the ast library?"
    ans = generate_answer(question)
    print(ans)