author     Nate Buttke <nate-web@riseup.net>  2023-07-25 22:40:00 -0700
committer  Nate Buttke <nate-web@riseup.net>  2023-07-25 22:40:00 -0700
commit     0985b7f2d467ecbeba0c6ca51ba03236cd4ff929 (patch)
tree       5c24e8f12cd4416c69c5a37c365af34a1119a47f
hi dan
-rw-r--r--  README.md         7
-rw-r--r--  search_code.py   66
-rw-r--r--  setup.py        158
-rw-r--r--  setup_cont.py    16
4 files changed, 247 insertions, 0 deletions
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..001f36b
--- /dev/null
+++ b/README.md
@@ -0,0 +1,7 @@
+forked from https://github.com/keerthanpg/TalkToCode
+
+Really messy and hard to explain right now; a refactor is planned. You need to
+install tree-sitter and build a shared object file to parse `go`, the only
+language that works right now.
+
+Lots of stuff is hard coded.
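The "object file" the README asks for is the Go grammar compiled into a shared library that setup.py later loads as ./tree-go.so. A minimal sketch of that build step, assuming a pre-0.22 py-tree-sitter (which still provides Language.build_library) and a local clone of tree-sitter-go under vendor/ (that path is an assumption):

# one-off build of the Go grammar; the output name matches what setup.py loads
from tree_sitter import Language

Language.build_library(
    "tree-go.so",               # compiled shared object for Language("./tree-go.so", "go")
    ["vendor/tree-sitter-go"],  # assumed location of a tree-sitter-go checkout
)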
diff --git a/search_code.py b/search_code.py
new file mode 100644
index 0000000..9d9ecee
--- /dev/null
+++ b/search_code.py
@@ -0,0 +1,66 @@
+import ast
+from collections import defaultdict
+import os
+import pandas as pd
+import openai
+import tiktoken
+from openai.embeddings_utils import get_embedding, cosine_similarity
+
+openai.api_key = os.getenv('END_OF_WORLD')
+
+#def clean():
+# df[col1] = df[col1].apply(lambda x: literal_eval(x) if "[" in x else x)
+
+# expects the CSV written by setup_cont.py; embedding_summary is stored as a stringified list
+df = pd.read_csv("setup_dataWithSummaryEmbed.csv", converters={'embedding_summary': pd.eval})
+
+#def phony(x):
+# print(type(x))
+# print(x)
+# exit()
+
+def search_code(df, query, n=3):
+    """Embed the query and rank code blocks by similarity to their summary embeddings."""
+    query_embedding = get_embedding(
+        query,
+        engine="text-embedding-ada-002"
+    )
+
+    df["similarity"] = df.embedding_summary.apply(lambda x: cosine_similarity(x, query_embedding))
+
+    # most similar code blocks first, truncated to the n requested
+    results = df.sort_values("similarity", ascending=False).head(n)
+    return results
+
+def generate_answer(question):
+ results = search_code(df, question, n=3)
+ prompt = ''
+ for i in range(3):
+ prompt += results.iloc[i]["summary"] + "\n" + results.iloc[i]["blob"] + "\n"
+ #prompt += "\n" + "Q: " + question + "\nA: "
+
+    prompt += ("\nAnswer the following question using the code context "
+               "given above, and show an example with 'Example'\nQ: " + question + "\nA: ")
+
+ print("PROMPT:")
+ print(prompt)
+
+ response = openai.Completion.create(
+ model="text-davinci-003",
+ prompt=prompt,
+ temperature=0.7,
+ max_tokens=1000,
+ top_p=1.0,
+ frequency_penalty=0.0,
+ presence_penalty=0.0,
+ stop=["\"\"\""]
+ )
+ return response["choices"][0]["text"]
+
+question = "how does the code in setup.py parse Python source code using the ast library?"
+ans = generate_answer(question)
+print(ans)
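For context, the ranking in search_code depends only on cosine similarity between the query embedding and each stored summary embedding. A standalone sketch of that computation with plain numpy and made-up vectors (openai.embeddings_utils.cosine_similarity does the equivalent):

import numpy as np

def cosine_sim(a, b):
    # dot product normalised by the vector lengths; 1.0 means identical direction
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

query_vec = [0.1, 0.3, 0.5]      # toy stand-ins for text-embedding-ada-002 vectors
summary_vec = [0.1, 0.2, 0.6]
print(cosine_sim(query_vec, summary_vec))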
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..80be931
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,158 @@
+import ast
+from collections import defaultdict
+import os
+import pandas as pd
+import openai
+import tiktoken
+from openai.embeddings_utils import get_embedding, cosine_similarity
+
+from tree_sitter import Language, Parser
+
+SOURCE_DIR = './'
+
+openai.api_key = os.getenv('END_OF_WORLD')
+
+#def get_block(code, node, code_type, file_path):
+# """combine a bunch of data about a function. return dictionary"""
+# blob = f"{node['pretext']}{ast.get_source_segment(code, node['node'])}"
+# return {
+# 'code_type': code_type,
+# 'source': blob,
+# 'start_line': node['node'].lineno,
+# 'end_line': node['node'].end_lineno,
+# 'chars': len(blob),
+# 'file_path': file_path
+# }
+
+def ts_query(lang, tree, sexp):
+ query = lang.query(sexp)
+ return query.captures(tree.root_node)
+
+def ts_get_all_code_blocks(lang, code_blocks, file_path, tree, code):
+ """Use treesitter to get all code blocks"""
+
+ # TODO need way to switch between declaration and definition ..
+ # e.g. golang does not have function definitions according to treesitter
+ results = ts_query(lang, tree, """(function_declaration) @function""")
+ results += ts_query(lang, tree, """(method_declaration) @method""")
+
+ # TODO something like list comprehension here
+    # each result from ts_query is a (node, capture_name) tuple
+    for r in results:
+ return_dict = {
+ 'code_type': r[1],
+ 'source': code[r[0].start_byte:r[0].end_byte].decode('utf-8'),
+ 'start_line': r[0].start_point[0],
+ 'end_line': r[0].end_point[0],
+ 'chars': r[0].end_byte - r[0].start_byte,
+ 'file_path': file_path
+ }
+ code_blocks.append(return_dict)
+
+
+def ts_get_all_code_blocks_old(code_blocks, file_path, node):
+ """Use treesitter to get all code blocks"""
+ #dict has'code_type' 'source' 'start_line' 'end_line' 'chars' 'file_path'
+ #print('HERRO', type(node))
+ for child in node.children:
+ #print(type(child), child)
+ return_dict = {
+ 'code_type': child.type,
+ 'start_line': child.start_point[0],
+ 'end_line': child.end_point[0],
+ 'chars': child.end_byte - child.start_byte,
+ 'file_path': file_path
+ }
+ code_blocks.append(return_dict)
+ #if child.type != "function_definition" and len(child.children)
+        ts_get_all_code_blocks_old(code_blocks, file_path, child)
+
+def parse_file(file_path):
+ """take source code file and return pd dataframe"""
+ # read file
+ with open(file_path, 'r') as f:
+ code = f.read()
+
+ # Tree-Sitter
+ parser = Parser()
+ lang = Language("./tree-go.so", "go")
+ parser.set_language(lang)
+ tree = parser.parse(bytes(code, "utf8"))
+
+ code_blocks = []
+ ts_get_all_code_blocks(lang, code_blocks, file_path, tree, bytes(code, "utf8"))
+
+
+ #TODO
+ # collate imports, assign
+ collate_types = ['import', 'assign']
+ tempblock = None
+ finblocks = []
+
+ for block in code_blocks:
+ if block['code_type'] in collate_types:
+ if tempblock is None:
+ tempblock = {k:v for k,v in block.items()}
+ elif tempblock['code_type'] == block['code_type']:
+ tempblock['source'] += f"\n{block['source']}"
+ tempblock['start_line'] = min(tempblock['start_line'], block['start_line'])
+                tempblock['end_line'] = max(tempblock['end_line'], block['end_line'])
+ tempblock['chars'] += (block['chars'] + 1)
+ else:
+ finblocks.append(tempblock)
+ tempblock = {k:v for k,v in block.items()}
+ else:
+ if tempblock is not None:
+ finblocks.append(tempblock)
+ tempblock = None
+ finblocks.append(block)
+ df = pd.DataFrame(finblocks)
+ return df
+
+
+def get_files_to_parse(root_path, files_extensions_to_parse=['py'], dirs_to_ignore=['tests']) -> list:
+ """get all source file paths as list."""
+ files_to_parse = []
+    for root, dirs, files in os.walk(root_path):
+ for name in files:
+ if (root.rsplit("/", 1)[-1] in dirs_to_ignore) or (name.rsplit('.')[-1] not in files_extensions_to_parse):
+ continue
+ temp_path = os.path.join(root, name)
+ files_to_parse.append(temp_path)
+ return files_to_parse
+
+def generate_summary(prompt):
+ prompt = prompt + '\nSummarize the above code: '
+ response = openai.Completion.create(
+ model="text-davinci-003",
+ prompt=prompt,
+ temperature=0.7,
+ max_tokens=1024,
+ top_p=1.0,
+ frequency_penalty=0.0,
+ presence_penalty=0.0,
+ stop=["\"\"\""]
+ )
+ return response["choices"][0]["text"]
+
+# nate function to create blob. the blob just contains the file path and the source code.
+def blobify(pandaSeries):
+ return f"file path: {pandaSeries['file_path']}\n {pandaSeries['source']}"
+
+
+### doing stuff!!
+
+df = parse_file("../../dirserver/src/dirserver/fdpoller.go")
+df.to_csv('test.csv')
+df["blob"] = df.apply(blobify, axis=1)
+
+print(type(df))
+print(df)
+
+df.to_csv('test_with_blob.csv')
+
+print('starting to generate summary')
+df["summary"] = df.blob.apply(lambda x: generate_summary(x))
+print('done with generate summary')
+
+df.to_csv('test_with_summary.csv')
+
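For reference, the r[0]/r[1] indexing in ts_get_all_code_blocks unpacks the (node, capture_name) tuples that query.captures() returns in the pre-0.22 py-tree-sitter API. A small sketch of that round trip on a toy Go snippet, assuming tree-go.so was built as described in the README:

from tree_sitter import Language, Parser

lang = Language("./tree-go.so", "go")
parser = Parser()
parser.set_language(lang)

src = b'package main\n\nfunc hello() string { return "hi" }\n'
tree = parser.parse(src)

for node, name in lang.query("(function_declaration) @function").captures(tree.root_node):
    # name is the capture label ("function"); the node carries byte and line ranges
    print(name, node.start_point[0], node.end_point[0])
    print(src[node.start_byte:node.end_byte].decode("utf-8"))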
diff --git a/setup_cont.py b/setup_cont.py
new file mode 100644
index 0000000..360c9f9
--- /dev/null
+++ b/setup_cont.py
@@ -0,0 +1,16 @@
+import ast
+from collections import defaultdict
+import os
+import pandas as pd
+import openai
+import tiktoken
+from openai.embeddings_utils import get_embedding, cosine_similarity
+
+openai.api_key = os.getenv('END_OF_WORLD')
+
+df = pd.read_csv("setup_dataWithSummary.csv")
+embedding_model = "text-embedding-ada-002"
+df["embedding_summary"] = df.summary.apply(lambda x: get_embedding(x, engine=embedding_model))
+print(df)
+
+df.to_csv('setup_dataWithSummaryEmbed.csv')
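One wrinkle worth remembering: to_csv stringifies the list-valued embedding_summary column, which is why search_code.py reads it back through a converter. A tiny illustrative round trip (toy column values) showing what that converter has to undo:

import ast
import pandas as pd

df = pd.DataFrame({"summary": ["polls fd events"],
                   "embedding_summary": [[0.01, -0.02, 0.03]]})  # toy embedding
df.to_csv("roundtrip.csv", index=False)

back = pd.read_csv("roundtrip.csv")
print(type(back.embedding_summary[0]))      # str: the list came back as text

back = pd.read_csv("roundtrip.csv",
                   converters={"embedding_summary": ast.literal_eval})
print(type(back.embedding_summary[0]))      # list: parsed back into Python values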