Mercurial > repos > bgruening > chatgpt_openai_api
comparison chatgpt.py @ 3:7770a4bd42e2 draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/chatgpt commit c21d9a2cb410ee30dc47f4a13247862481816266
author | bgruening |
---|---|
date | Wed, 11 Sep 2024 16:36:21 +0000 |
parents | dab494dce303 |
children |
comparison
equal
deleted
inserted
replaced
2:dab494dce303 | 3:7770a4bd42e2 |
---|---|
import json
import os
import sys

from openai import AuthenticationError, OpenAI

# CLI contract: argv[1] = JSON list of [path, type] pairs describing context
# files, argv[2] = user question, argv[3] = model name, argv[4] = path to a
# file containing the OpenAI API key (from Galaxy user preferences).
context_files = json.loads(sys.argv[1])
question = sys.argv[2]
model = sys.argv[3]
with open(sys.argv[4], "r") as f:
    openai_api_key = f.read().strip()
if not openai_api_key:
    print("OpenAI API key is not provided in user preferences!")
    sys.exit(1)

client = OpenAI(api_key=openai_api_key)

# Open handles handed to the vector-store batch upload further down.
file_search_file_streams = []
# Prompt fragments referencing already-uploaded vision files.
image_files = []

# Renamed loop variable from `type` to `file_type`: don't shadow the builtin.
for path, file_type in context_files:
    if file_type == "image":
        # The vision upload endpoint rejects images larger than 20 MB.
        if os.path.getsize(path) > 20 * 1024 * 1024:
            print(f"File {path} exceeds the 20MB limit and will not be processed.")
            sys.exit(1)
        # Use a context manager so the image handle is closed after upload
        # (the previous version leaked it).
        with open(path, "rb") as image_fh:
            file = client.files.create(file=image_fh, purpose="vision")
        # Typo fix: `promt` -> `prompt` (local name only).
        prompt = {"type": "image_file", "image_file": {"file_id": file.id}}
        image_files.append(prompt)
    else:
        # Deliberately left open: these streams are consumed (and closed) by
        # vector_stores.file_batches.upload_and_poll later in the script.
        file_search_file_streams.append(open(path, "rb"))
# System instructions shown to the assistant: covers both document retrieval
# (file_search) questions and image-analysis questions.
_ASSISTANT_INSTRUCTIONS = (
    "You will receive questions about files from file searches "
    "and image files. For file search queries, identify and "
    "retrieve the relevant files based on the question. "
    "For image file queries, analyze the image content and "
    "provide relevant information or insights based on the image data."
)

# Only enable the file_search tool when there are documents to search over.
_assistant_tools = [{"type": "file_search"}] if file_search_file_streams else []

try:
    assistant = client.beta.assistants.create(
        instructions=_ASSISTANT_INSTRUCTIONS,
        model=model,
        tools=_assistant_tools,
    )
except AuthenticationError as e:
    # Invalid or expired API key: surface the SDK's message and abort.
    print(f"Authentication error: {e.message}")
    sys.exit(1)
except Exception as e:
    # Top-level script boundary: report anything unexpected and abort.
    print(f"An error occurred: {str(e)}")
    sys.exit(1)
62 if file_search_file_streams: | 51 if file_search_file_streams: |
63 vector_store = client.beta.vector_stores.create() | 52 vector_store = client.beta.vector_stores.create() |
64 file_batch = client.beta.vector_stores.file_batches.upload_and_poll( | 53 file_batch = client.beta.vector_stores.file_batches.upload_and_poll( |
65 vector_store_id=vector_store.id, files=file_search_file_streams | 54 vector_store_id=vector_store.id, files=file_search_file_streams |
66 ) | 55 ) |
86 thread_id=thread.id, assistant_id=assistant.id | 75 thread_id=thread.id, assistant_id=assistant.id |
87 ) | 76 ) |
# Gather the assistant's reply messages produced by this run.
assistant_messages = list(
    client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id)
)
if not assistant_messages:
    # An empty reply most commonly means the account has no remaining credits.
    print(
        "No output was generated!\n"
        "Please ensure that your OpenAI account has sufficient credits.\n"
        "You can check your balance here: https://platform.openai.com/settings/organization/billing"
    )
    sys.exit(1)

# Persist the text of the first reply message to output.txt.
first_message = assistant_messages[0]
message_content = first_message.content[0].text.value
print("Output has been saved!")
with open("output.txt", "w") as out_fh:
    out_fh.write(message_content)
96 | 90 |