benchmark_transform/
main.rs

//! Reads a generated `weights.rs` file that includes child trie accesses and replaces those
//! DB reads and writes with the child-trie-specific DB weights defined in `rocksdb_child_trie_weights.rs`.
3use regex::Regex;
4use std::{
5	env, fs,
6	io::{self},
7	process,
8};
9
10fn main() -> io::Result<()> {
11	// Get command line arguments
12	let args: Vec<String> = env::args().collect();
13
14	if args.len() != 3 {
15		eprintln!("Usage: {} <input_weights_file> <output_weights_file>", args[0]);
16		eprintln!("Example: {} weights.rs transformed_weights.rs", args[0]);
17		process::exit(1);
18	}
19
20	let input_file = &args[1];
21	let output_file = &args[2];
22
23	// Read the input file
24	let input = fs::read_to_string(input_file).map_err(|e| {
25		eprintln!("Error reading file '{input_file}': {e}");
26		e
27	})?;
28
29	// Check if already modified
30	if is_already_modified(&input) {
31		println!("✓ It was already modified and no need for transformation!");
32		return Ok(())
33	}
34
35	// Transform the content
36	let output = transform_weights(&input);
37
38	// Write to output file
39	fs::write(output_file, &output).map_err(|e| {
40		eprintln!("Error writing to file '{output_file}': {e}");
41		e
42	})?;
43
44	println!("✓ Transformation complete!");
45	println!("  Input:  {input_file}");
46	println!("  Output: {output_file}");
47	Ok(())
48}
49
/// Reports whether `content` already carries the `'benchmark_transform'` marker that the
/// transformation's disclaimer line contains.
fn is_already_modified(content: &str) -> bool {
	const MARKER: &str = "'benchmark_transform'";
	content.contains(MARKER)
}

54fn transform_weights(content: &str) -> String {
55	let mut result = content.to_string();
56
57	// Step 1: Adds the modification disclaimer
58	result = ["\n//! MODIFIED by 'benchmark_transform' tool to replace child trie storage access with their specific DB weights", &result].join("\n//!");
59
60	// Step 2: Add RocksDbWeightChild import after the PhantomData import
61	let import_pattern = "use core::marker::PhantomData;";
62	let new_import = "use core::marker::PhantomData;\nuse common_runtime::weights::rocksdb_child_trie_weights::constants::RocksDbWeightChild;";
63	result = result.replace(import_pattern, new_import);
64
65	// Step 3: Process each function in SubstrateWeight<T> and the () implementation
66	result = process_weight_implementations(&result);
67
68	result
69}
70
71fn process_weight_implementations(content: &str) -> String {
72	let mut result = String::new();
73	let lines: Vec<&str> = content.lines().collect();
74	let mut i = 0;
75
76	while i < lines.len() {
77		let line = lines[i];
78		result.push_str(line);
79		result.push('\n');
80
81		// Check if we're at a function definition
82		if line.trim().starts_with("fn ") && line.contains("-> Weight {") {
83			// Collect function body lines and storage comments
84			let mut func_lines = vec![];
85			let mut storage_lines = Vec::new();
86
87			// Look backwards for Storage comments
88			let mut j = i;
89			while j > 0 {
90				j -= 1;
91				let prev_line = lines[j];
92				if prev_line.trim().starts_with("/// Storage:") {
93					storage_lines.insert(0, prev_line);
94				} else if !prev_line.trim().starts_with("///") {
95					break;
96				}
97			}
98
99			// Collect function body
100			let mut brace_count = 0;
101			let mut in_function = false;
102
103			while i < lines.len() {
104				let current = lines[i];
105				func_lines.push(current);
106
107				if current.contains('{') {
108					in_function = true;
109					brace_count += current.matches('{').count();
110				}
111				if current.contains('}') {
112					brace_count -= current.matches('}').count();
113				}
114
115				if in_function && brace_count == 0 {
116					break;
117				}
118				i += 1;
119			}
120
121			// Count UNKNOWN keys
122			let unknown_reads = count_unknown_keys(&storage_lines, "r:");
123			let unknown_writes = count_unknown_keys(&storage_lines, "w:");
124
125			if unknown_reads > 0 || unknown_writes > 0 {
126				// Process the function body
127				let processed =
128					process_function_body(&func_lines.join("\n"), unknown_reads, unknown_writes);
129				// Remove the original function line since we already added it
130				let processed_lines: Vec<&str> = processed.lines().collect();
131				// Skip first line as it's already added
132				for pline in &processed_lines[1..] {
133					result.push_str(pline);
134					result.push('\n');
135				}
136				i += 1;
137				continue;
138			}
139		}
140
141		i += 1;
142	}
143
144	result
145}
146
/// Sums the access counts for `operation` (e.g. `"r:"` or `"w:"`) across the storage lines
/// that mention `UNKNOWN KEY`.
///
/// For each such line, the digits right after the first occurrence of `operation` are
/// parsed up to the next whitespace or `)`. Lines where that number cannot be parsed, or
/// where it is not followed by a terminator, contribute nothing.
fn count_unknown_keys(storage_lines: &[&str], operation: &str) -> u64 {
	storage_lines
		.iter()
		.copied()
		.filter(|line| line.contains("UNKNOWN KEY"))
		.filter_map(|line| {
			let value_start = line.find(operation)? + operation.len();
			let tail = &line[value_start..];
			let value_end = tail.find(|c: char| c.is_whitespace() || c == ')')?;
			tail[..value_end].parse::<u64>().ok()
		})
		.sum()
}

165fn process_function_body(func_body: &str, unknown_reads: u64, unknown_writes: u64) -> String {
166	let lines: Vec<&str> = func_body.lines().collect();
167	let mut result = Vec::new();
168	let mut i = 0;
169	let read_regex = Regex::new(r"reads\((\d+)_u64\)").expect("Should create regex.");
170	let write_regex = Regex::new(r"writes\((\d+)_u64\)").expect("Should create regex.");
171
172	while i < lines.len() {
173		let line = lines[i];
174
175		// Check for T::DbWeight::get().reads() or RocksDbWeight::get().reads()
176		if (line.contains("T::DbWeight::get().reads(") ||
177			line.contains("RocksDbWeight::get().reads(")) &&
178			unknown_reads > 0
179		{
180			// Extract the current read count
181			if let Some(caps) = read_regex.captures(line) {
182				if let Ok(current_reads) = caps[1].parse::<u64>() {
183					// Only subtract if current_reads is greater than unknown_reads
184					if current_reads >= unknown_reads {
185						let new_reads = current_reads - unknown_reads;
186						let indent = get_indent(line);
187
188						// Determine which DbWeight to use
189						let db_weight = if line.contains("T::DbWeight") {
190							"T::DbWeight"
191						} else {
192							"RocksDbWeight"
193						};
194
195						// Add the modified line
196						if new_reads > 0 {
197							result.push(format!(
198								"{indent}.saturating_add({db_weight}::get().reads({new_reads}_u64))"
199							));
200						}
201						// Add the child trie reads line
202						result.push(format!(
203							"{indent}.saturating_add(RocksDbWeightChild::get().reads({unknown_reads}_u64))"
204						));
205						i += 1;
206						continue;
207					}
208				}
209			}
210		}
211
212		// Check for T::DbWeight::get().writes() or RocksDbWeight::get().writes()
213		if (line.contains("T::DbWeight::get().writes(") ||
214			line.contains("RocksDbWeight::get().writes(")) &&
215			unknown_writes > 0
216		{
217			if let Some(caps) = write_regex.captures(line) {
218				if let Ok(current_writes) = caps[1].parse::<u64>() {
219					// Only subtract if current_writes is greater than unknown_writes
220					if current_writes >= unknown_writes {
221						let new_writes = current_writes - unknown_writes;
222						let indent = get_indent(line);
223
224						let db_weight = if line.contains("T::DbWeight") {
225							"T::DbWeight"
226						} else {
227							"RocksDbWeight"
228						};
229
230						if new_writes > 0 {
231							result.push(format!(
232								"{indent}.saturating_add({db_weight}::get().writes({new_writes}_u64))"
233							));
234						}
235						result.push(format!(
236							"{indent}.saturating_add(RocksDbWeightChild::get().writes({unknown_writes}_u64))"
237						));
238						i += 1;
239						continue;
240					}
241				}
242			}
243		}
244
245		result.push(line.to_string());
246		i += 1;
247	}
248
249	result.join("\n")
250}
251
/// Returns the leading whitespace of `line` (whitespace as defined by `char::is_whitespace`).
fn get_indent(line: &str) -> String {
	line.chars().take_while(|c| c.is_whitespace()).collect()
}