-
Notifications
You must be signed in to change notification settings - Fork 0
/
generation.jl
120 lines (107 loc) · 3.87 KB
/
generation.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
export gen_code
using ProgressMeter
using OpenAI
using PromptingTools
const PT = PromptingTools
"""
    gen_prompt(model, task::HumanEvalTask; chain_of_thought=false)

Fallback prompt builder used for any model without a specialized method.
Return an OpenAI-style chat message list: a system instruction followed by
the task's doc string as the user turn. When `chain_of_thought` is true, the
system message additionally asks for a step-by-step outline before the code.
"""
function gen_prompt(model, task::HumanEvalTask; chain_of_thought=false)
    suffix = chain_of_thought ? " You need first to write a step-by-step outline and then write the code." : ""
    system_msg = Dict(
        "role" => "system",
        "content" => "You are an intelligent programming assistant to produce Julia algorithmic solutions. Please implement the following Julia function based on given doc string.$suffix",
    )
    user_msg = Dict("role" => "user", "content" => task.prompt)
    return [system_msg, user_msg]
end
"""
    gen_prompt(model::Val{Symbol("Yi-34B-Chat")}, task::HumanEvalTask; chain_of_thought=false)

Prompt builder specialized for the `Yi-34B-Chat` model: instead of a separate
system turn, the instruction and the task prompt (wrapped in a fenced
```julia``` block) are merged into one single user message.
"""
function gen_prompt(model::Val{Symbol("Yi-34B-Chat")}, task::HumanEvalTask; chain_of_thought=false)
    cot = chain_of_thought ? " You need first to write a step-by-step outline and then write the code." : ""
    # Keep the triple-quoted text at column 0 so the rendered message is
    # exactly the instruction + fenced code block, with no extra indentation.
    content = """You are an intelligent programming assistant to produce Julia algorithmic solutions. Please implement the following Julia function in the markdown format based on given doc string.$cot
```julia
$(task.prompt)
```
"""
    return [Dict("role" => "user", "content" => content)]
end
"""
    gen_reply(model, task::HumanEvalTask; chain_of_thought=false, kw...)

Request completions for `task` from a chat backend and return the reply
texts as a vector of strings (one element per returned choice/content block).

Backend selection: if `ANTHROPIC_API_KEY` is set in `ENV`, the Anthropic API
is called via PromptingTools (the first prompt message is passed as the
system prompt and the last as the user message; `ANTHROPIC_BASE_URL` must
also be set). Otherwise an OpenAI-compatible endpoint configured through
`OPENAI_API_KEY`/`OPENAI_BASE_URL` is used; extra keyword args (`n`,
`temperature`, ...) are forwarded to the OpenAI call only.
"""
function gen_reply(model, task::HumanEvalTask; chain_of_thought=false, kw...)
    messages = gen_prompt(Val(Symbol(model)), task; chain_of_thought)
    if haskey(ENV, "ANTHROPIC_API_KEY")
        resp = PT.anthropic_api(
            PT.AnthropicSchema();
            messages=[messages[end]],
            model=model,
            url=ENV["ANTHROPIC_BASE_URL"],
            api_key=ENV["ANTHROPIC_API_KEY"],
            system=messages[begin]["content"],
        )
        return [c[:text] for c in resp.response[:content]]
    end
    provider = OpenAI.OpenAIProvider(
        api_key=ENV["OPENAI_API_KEY"],
        base_url=ENV["OPENAI_BASE_URL"],
    )
    result = create_chat(provider, model, messages; kw...)
    return [c[:message][:content] for c in result.response[:choices]]
end
"""
    gen_code(model; temperature=0.0, n_samples=200, batch_size=10, chain_of_thought=false, kw...)

Generate `n_samples` completions per task with `model` and write them to
disk, then extract the Julia code from each raw generation into a runnable
`.jl` file. Already-existing generation files are skipped, so an
interrupted run can be resumed. Directory layout (relative to the package
root):

```
generations
model_name
temperature
seed_1
task_name.jl
task_name.prompt.txt
task_name.generation.txt
...
seed_2
...
```

NOTE(review): the layout above lists `task_name.prompt.txt`, but this
function only writes `*.generation.txt` and `*.jl` — confirm whether the
prompt file is produced elsewhere.
"""
function gen_code(model; temperature=0.0, n_samples=200, batch_size=10, chain_of_thought=false, kw...)
    # Greedy decoding (temperature 0.0) always yields the same output,
    # so a single sample/batch is enough.
    if temperature == 0.0
        n_samples, batch_size = 1, 1
    end
    # Output root: generations/<model>[-COT]/<temperature>/
    GEN_DIR = joinpath(
        @__DIR__,
        "..",
        "generations",
        chain_of_thought ? "$(model)-COT" : model,
        "$temperature",
    )
    # One sub-directory per sample index (the "seed_i" dirs in the layout).
    for i = 1:n_samples
        mkpath(joinpath(GEN_DIR, "$i"))
    end
    # gen response
    @showprogress desc = "[$model]Generating..." for t in get_tasks()
        # Resume support: find the first sample index whose generation file
        # is missing; indices before it are treated as already done.
        start = findfirst(
            i -> !isfile(joinpath(GEN_DIR, "$i", "$(nameof(t)).generation.txt")),
            1:n_samples,
        )
        # All samples for this task already exist -> nothing to do.
        isnothing(start) && continue
        # Request completions in batches: the `n` keyword asks the backend
        # for up to `batch_size` replies at once (clamped at n_samples).
        for i = start:batch_size:n_samples
            reply =
                gen_reply(model, t; temperature, n=min(i + batch_size, n_samples + 1) - i, chain_of_thought, kw...)
            # Write the j-th reply of this batch into sample dir i+j-1.
            for (j, r) in enumerate(reply)
                open(
                    joinpath(GEN_DIR, "$(i+j-1)", "$(nameof(t)).generation.txt"),
                    "w",
                ) do io
                    write(io, r)
                end
            end
        end
    end
    # extract julia code
    for t in get_tasks()
        for i in 1:n_samples
            s = read(joinpath(GEN_DIR, "$i", "$(nameof(t)).generation.txt"), String)
            open(joinpath(GEN_DIR, "$i", "$(nameof(t)).jl"), "w") do io
                snippets = extract_julia_code_blocks(s)
                # If no fenced code was found, the .jl file is left empty.
                if !isempty(snippets)
                    # Prepend the task prompt so the extracted file is
                    # self-contained (doc string + generated body).
                    println(io, t.prompt)
                    if !startswith(strip(snippets[1]), "function")
                        # we place extra line to avoid doc string error
                        println(io)
                    end
                    for s in snippets
                        println(io, s)
                    end
                end
            end
        end
    end
end